Merge remote-tracking branch 'gitlab/main' into user/xiweny/merge_main_0819

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Author: Xiwen Yu
Date: 2025-08-23 16:13:30 +08:00
Commit: 808059da34
191 changed files with 5878 additions and 2986 deletions

.gitattributes (vendored), 2 changed lines
View File

@ -7,3 +7,5 @@
triton_backend/tools/gpt/input_data.json filter=lfs diff=lfs merge=lfs -text
*cubin.cpp filter=lfs diff=lfs merge=lfs -text
docs/source/blogs/media/tech_blog3_mla_absorb.png filter=lfs diff=lfs merge=lfs -text
tests/integration/test_input_files/*.png filter=lfs diff=lfs merge=lfs -text
tests/integration/test_input_files/*.jpg filter=lfs diff=lfs merge=lfs -text

View File

@ -9,7 +9,7 @@ TensorRT-LLM
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-1.1.0rc1-green)](./tensorrt_llm/version.py)
[![version](https://img.shields.io/badge/release-1.1.0rc2-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
[Architecture](./docs/source/torch/arch_overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Performance](./docs/source/performance/perf-overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](./docs/source/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
@ -18,10 +18,9 @@ TensorRT-LLM
<div align="left">
## Tech Blogs
* [08/06] Running a High Performance GPT-OSS-120B Inference Server with TensorRT-LLM
* [08/05] Running a High-Performance GPT-OSS-120B Inference Server with TensorRT-LLM
✨ [➡️ link](./docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md)
* [08/01] Scaling Expert Parallelism in TensorRT-LLM (Part 2: Performance Status and Optimization)
✨ [➡️ link](./docs/source/blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.md)
@ -44,6 +43,7 @@ TensorRT-LLM
✨ [➡️ link](./docs/source/blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.md)
## Latest News
* [08/05] 🌟 TensorRT-LLM delivers Day-0 support for OpenAI's latest open-weights models: GPT-OSS-120B [➡️ link](https://huggingface.co/openai/gpt-oss-120b) and GPT-OSS-20B [➡️ link](https://huggingface.co/openai/gpt-oss-20b)
* [07/15] 🌟 TensorRT-LLM delivers Day-0 support for LG AI Research's latest model, EXAONE 4.0 [➡️ link](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B)
* [06/17] Join NVIDIA and DeepInfra for a developer meetup on June 26 ✨ [➡️ link](https://events.nvidia.com/scaletheunscalablenextgenai)
* [05/22] Blackwell Breaks the 1,000 TPS/User Barrier With Meta's Llama 4 Maverick
@ -253,5 +253,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer
## Useful Links
- [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT-LLM.
- [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT-LLM.
- [AutoDeploy](./examples/auto_deploy/README.md): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
- [AutoDeploy](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
- [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT-LLM Q&A and news.

View File

@ -69,7 +69,7 @@ add_compile_definitions("TLLM_GEN_EXPORT_INTERFACE")
add_compile_definitions("TLLM_ENABLE_CUDA")
set(BINDING_TYPE
"pybind"
"nanobind"
CACHE STRING
"Binding type of Python bindings for C++ runtime and batch manager")

View File

@ -24,7 +24,6 @@
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/request.h"
#include "tensorrt_llm/runtime/worldConfig.h"
namespace tensorrt_llm::runtime
@ -88,37 +87,6 @@ public:
SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;
private:
//! @brief Setups decoder internal tensors for new speculative decoding request
static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode,
SizeType32 maxDecodingEngineTokens);
//! @brief Setups decoder internal tensors for new request in Draft model Sps mode
static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream);
//! @brief Setups decoder internal tensors for new Medusa request
static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens);
//! @brief Setups decoder internal tensors for new Lookahead request
static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
//! @brief Setups decoder internal tensors for new Explicit draft tokens request
static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
//! @brief Setups decoder internal tensors for new Eagle request
static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
[[nodiscard]] std::shared_ptr<runtime::ITensor> retrieveDraftLogits(runtime::ModelConfig const& modelConfig,
runtime::WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
runtime::BufferManager const& bufferManager) const;
bool mSpeculativeDecodingFastLogits;
bool mIsLeaderInOrchMode;
bool mIsNormalizeLogProbs;

View File

@ -1110,7 +1110,7 @@ public:
[[nodiscard]] SizeType32 getNumDraftTokens() const
{
return mDraftTokens->size();
return hasDraftTokens() ? mDraftTokens->size() : 0;
}
void discardDraftTokens(SizeType32 numTokensToDiscard)

View File

@ -54,20 +54,21 @@ public:
#if defined(_MSC_VER)
template <typename... Args>
void log(Level level, char const* format, Args const&... args);
void log(Level const level, char const* format, Args const&... args);
template <typename... Args>
void log(Level level, int rank, char const* format, Args const&... args);
void log(Level const level, int const rank, char const* format, Args const&... args);
#else
template <typename... Args>
void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
void log(Level const level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
template <typename... Args>
void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
void log(Level const level, int const rank, char const* format, Args const&... args)
__attribute__((format(printf, 4, 0)));
#endif
template <typename... Args>
void log(Level level, std::string const& format, Args const&... args)
void log(Level const level, std::string const& format, Args const&... args)
{
return log(level, format.c_str(), args...);
}
@ -134,7 +135,7 @@ private:
};
template <typename... Args>
void Logger::log(Logger::Level level, char const* format, Args const&... args)
void Logger::log(Logger::Level const level, char const* format, Args const&... args)
{
if (isEnabled(level))
{

View File

@ -52,29 +52,30 @@ public:
AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
: mModelConfig(std::move(modelConfig))
, mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), worldConfig.getTensorParallelism()}
worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
worldConfig.getTensorParallelism()}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
}
CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
int DPrank = 0, int DPsize = 0)
SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
: mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
, mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
, mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
}
CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
int DPrank = 0, int DPsize = 0)
SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
: mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
, mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
, mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
@ -83,7 +84,7 @@ public:
[[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept
{
return mModelConfig == other.mModelConfig && mParallelConfig == other.mParallelConfig
&& mDataType == other.mDataType;
&& mAttentionConfig == other.mAttentionConfig && mDataType == other.mDataType;
}
struct ModelConfig
@ -103,6 +104,7 @@ public:
{
SizeType32 mTensorParallelism;
SizeType32 mPipelineParallelism;
SizeType32 mContextParallelism;
bool mEnableAttentionDP;
SizeType32 mDPrank;
SizeType32 mDPsize;
@ -110,8 +112,8 @@ public:
[[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
{
return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
&& mEnableAttentionDP == other.mEnableAttentionDP && mDPrank == other.mDPrank
&& mDPsize == other.mDPsize;
&& mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
&& mDPrank == other.mDPrank && mDPsize == other.mDPsize;
}
};
@ -125,6 +127,11 @@ public:
{
}
[[nodiscard]] bool operator==(AttentionConfig const& other) const noexcept
{
return mAttentionType == other.mAttentionType && mKvFactor == other.mKvFactor;
}
// attentionType ;
AttentionType mAttentionType;
int mKvFactor;
@ -162,6 +169,7 @@ public:
sstring << "mTokensPerBlock:" << mModelConfig.mTokensPerBlock << "\n";
sstring << "tp:" << mParallelConfig.mTensorParallelism << "\n";
sstring << "pp:" << mParallelConfig.mPipelineParallelism << "\n";
sstring << "cp:" << mParallelConfig.mContextParallelism << "\n";
sstring << "enableAttentionDP:" << mParallelConfig.mEnableAttentionDP << "\n";
sstring << "datatype:" << static_cast<int32_t>(mDataType) << "\n";
sstring << "attentionType:" << static_cast<int32_t>(mAttentionConfig.mAttentionType) << "\n";

View File

@ -102,11 +102,13 @@ public:
{
public:
TensorPtr draftLogits;
TensorPtr draftLogitsHost;
TensorPtr draftProbs;
TensorPtr targetProbs;
TensorPtr numDraftTokens;
TensorPtr numDraftTokensHost;
TensorPtr draftTokenIds;
TensorPtr draftTokenIdsHost;
TensorPtr useDraftLogits;
TensorPtr useDraftLogitsHost;

View File

@ -1,54 +0,0 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <optional>
namespace tensorrt_llm::runtime::decoder_batch
{
class Request
{
public:
using TensorConstPtr = ITensor::SharedConstPtr;
using TensorPtr = ITensor::SharedPtr;
using BufferPtr = IBuffer::SharedPtr;
explicit Request(SizeType32 inputLen)
: inputLen(inputLen)
{
}
//! Mandatory parameters
SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps
// optional parameters
SizeType32 generatedTokensPerEngineStep{1}; //
//! Optional parameters for speculative decoding
BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu
std::optional<TensorPtr> draftLogits; // [generatedTokensPerEngineStep - 1, vocabSize] on gpu
TensorPtr medusaPaths; // [maxDecodingTokens, maxPathLen], on gpu
TensorPtr medusaTreeIds; // [maxDecodingTokens], on gpu
std::optional<executor::LookaheadDecodingConfig> lookaheadRuntimeConfig;
std::optional<executor::EagleConfig> eagleConfig;
};
} // namespace tensorrt_llm::runtime::decoder_batch

View File

@ -1012,7 +1012,7 @@ CUBIN_EXPORT __global__
if (threadIdx.x < smem.gemm1AccColMax.size)
{
auto const idx = threadIdx.x;
smem.gemm1AccColMax[idx] = mha::numeric_limits<float>::lowest();
smem.gemm1AccColMax[idx] = safeInitRowMax;
smem.gemm1AccColSum[idx] = 0;
}
smem.gemm1WarpGrpBar.arrive_and_wait();
@ -1949,7 +1949,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
uint32_t const globalRow = tileStartRow + row;
if (globalRow >= cacheSeqLen)
{
acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
acc(m, n)(i, j) = safeInitRowMax;
continue;
}
if (globalRow >= maskStartRow)
@ -1957,7 +1957,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
uint32_t const maskRow = globalRow - maskStartRow;
if ((bit_mask >> maskRow) == 0)
{
acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
acc(m, n)(i, j) = safeInitRowMax;
}
}
}
@ -2087,7 +2087,7 @@ __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32
#pragma unroll
for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
{
acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
acc(m, n)(i, j) = safeInitRowMax;
}
}
}
@ -2380,9 +2380,9 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
{
uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j;
assert((col < nbValidCols) == bool(endMask & (1ULL << col)));
if (((mask >> col) & 1) == 0)
if ((mask & (1ULL << col)) == 0)
{
acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
acc(m, n)(i, j) = safeInitRowMax;
}
}
}
@ -2410,7 +2410,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uin
#pragma unroll
for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++)
{
acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
acc(m, n)(i, j) = safeInitRowMax;
}
}
}

View File

@ -833,7 +833,7 @@ public:
// Runs for 3 iterations or 1 second and picks the best option
int pickBestTactic(MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile)
{
auto tactics = mMoERunner.getTactics();
auto tactics = mMoERunner.getTactics(static_cast<MoeGemmId>(gemm_to_profile));
::nvtx3::scoped_range nvtx(tensorrt_llm::common::nvtx::nextColor(),
"Tactic Profiling GEMM " + std::to_string(static_cast<int>(gemm_to_profile)));
// We save space by reusing the same workspace buffer for all tactics when doing full layer profiling. So we
@ -925,12 +925,14 @@ public:
std::pair<int, int> setTactic(
int tactic_idx1, int tactic_idx2, MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile)
{
auto tactics = mMoERunner.getTactics();
auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1);
auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2);
std::vector<std::pair<std::reference_wrapper<int>, GemmToProfile>> tactics_to_profile{
{tactic_idx1, GemmToProfile::GEMM_1}, {tactic_idx2, GemmToProfile::GEMM_2}};
for (auto& combo : tactics_to_profile)
{
auto& t = combo.first.get();
auto& tactics = combo.second == GemmToProfile::GEMM_1 ? tactics1 : tactics2;
if (combo.second != gemm_to_profile && gemm_to_profile != GemmToProfile::LAYER)
{
t = 0; // Unneeded tactic, set to 0
@ -947,7 +949,7 @@ public:
}
}
mMoERunner.setTactic(tactics[tactic_idx1], tactics[tactic_idx2]);
mMoERunner.setTactic(tactics1[tactic_idx1], tactics2[tactic_idx2]);
mBestTacticGemm1 = tactic_idx1;
mBestTacticGemm2 = tactic_idx2;
return {tactic_idx1, tactic_idx2};
@ -965,7 +967,7 @@ public:
auto expert_weights_size
= gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1Size : mExpertWeight2Size;
auto tactics = mMoERunner.getTactics()[tactic_idx];
auto tactics = mMoERunner.getTactics(static_cast<MoeGemmId>(gemm_to_profile))[tactic_idx];
if (static_cast<int>(gemm_to_profile) != static_cast<int>(mGemmProfilerBackend.mGemmToProfile))
{
throw std::runtime_error("Configuration mismatch between mGemmProfilerBackend and runMoEPermute");
@ -1074,11 +1076,12 @@ void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state
}
if (LOG_LEVEL >= INFO)
{
auto tactics = mMoERunner.getTactics();
std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics.size() << "\n"
<< tactics[tactic_idx1].toString() << std::endl;
std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics.size() << "\n"
<< tactics[tactic_idx2].toString() << std::endl;
auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1);
auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2);
std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics1.size() << "\n"
<< tactics1[tactic_idx1].toString() << std::endl;
std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics2.size() << "\n"
<< tactics2[tactic_idx2].toString() << std::endl;
}
state.counters["tactic_idx1"] = tactic_idx1;
state.counters["tactic_idx2"] = tactic_idx2;

View File

@ -42,148 +42,15 @@ struct WeightParams
->Apply(argGen<MixtureOfExpertsBenchmark<WeightParams<atype, wtype, otype>>>)
template <class BenchClass>
auto listAllTactics()
auto listAllTactics(MoeGemmId gemm_id)
{
int const sm = getSMVersion();
using RunnerType = decltype(BenchClass::mMoERunner);
return RunnerType::getTactics(sm);
return RunnerType::getTactics(sm, gemm_id);
}
template <class BenchClass>
int parseTacticToId(nlohmann::json tactic_config)
{
bool is_tma_warp_specialized = tactic_config.at("is_tma_warp_specialized").get<bool>();
int tile_shape_id = -1;
std::array<int, 3> tile_shape;
if (tactic_config.at("tile_shape").is_array())
tactic_config.at("tile_shape").get_to(tile_shape);
else
tile_shape_id = tactic_config.at("tile_shape").get<int>();
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> confs = listAllTactics<BenchClass>();
try
{
for (int i = 0; i < confs.size(); i++)
{
auto const& c = confs[i];
if (c.is_tma_warp_specialized != is_tma_warp_specialized)
continue;
if (!is_tma_warp_specialized)
{
int stages = tactic_config.at("stages").get<int>();
if (c.stages != stages)
continue;
}
if (tile_shape_id != -1)
{
int comp = c.getTileConfigAsInt();
if (tile_shape_id != comp)
continue;
if (is_tma_warp_specialized && (int) c.cluster_shape != tactic_config.at("cluster_shape").get<int>())
continue;
// Found matching config
return i;
}
// Handle if the user provided a shape instead of the enum value
if (is_tma_warp_specialized)
{
// TODO Add cases for blackwell shapes
using Kv = uint64_t;
constexpr static auto K = [](int m, int n) { return (uint64_t(m) << 32) | uint64_t(n); };
static std::unordered_map<Kv, CutlassTileConfigSM90> const tile_map{
{K(64, 16), CutlassTileConfigSM90::CtaShape64x16x128B},
{K(64, 32), CutlassTileConfigSM90::CtaShape64x32x128B},
{K(64, 64), CutlassTileConfigSM90::CtaShape64x64x128B},
{K(64, 128), CutlassTileConfigSM90::CtaShape64x128x128B},
{K(64, 256), CutlassTileConfigSM90::CtaShape64x256x128B},
{K(128, 16), CutlassTileConfigSM90::CtaShape128x16x128B},
{K(128, 32), CutlassTileConfigSM90::CtaShape128x32x128B},
{K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B},
{K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B},
{K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B},
{K(256, 128), CutlassTileConfigSM90::CtaShape256x128x128B},
};
if (c.getTileConfigAsInt() != (int) tile_map.at(K(tile_shape[0], tile_shape[1])))
continue;
static std::unordered_map<Kv, ClusterShape> const cluster_map{
// CTA configs for M=64
{K(1, 1), ClusterShape::ClusterShape_1x1x1},
{K(2, 1), ClusterShape::ClusterShape_2x1x1},
{K(1, 2), ClusterShape::ClusterShape_1x2x1},
{K(2, 2), ClusterShape::ClusterShape_2x2x1},
};
std::array<int, 3> cluster_shape;
tactic_config.at("cluster_shape").get_to(cluster_shape);
if (c.cluster_shape != cluster_map.at(K(cluster_shape[0], cluster_shape[1])))
continue;
// Found matching config
return i;
}
else
{
std::array<int, 3> warp_shape;
tactic_config.at("warp_shape").get_to(warp_shape);
using Kv = uint64_t;
constexpr static auto K = [](std::array<int, 3> a, std::array<int, 3> b)
{
uint64_t sum = 0;
for (auto v : a)
sum = sum * 512 + v;
for (auto v : b)
sum = sum * 256 + v;
return sum;
};
static std::unordered_map<Kv, CutlassTileConfig> tile_map{
{K({128, 128, 8}, {64, 64, 8}), CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8},
{K({16, 128, 64}, {16, 32, 64}), CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64},
{K({32, 128, 64}, {32, 32, 64}), CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64},
{K({64, 128, 64}, {32, 64, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64},
{K({64, 64, 128}, {32, 64, 64}), CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64},
{K({64, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64},
{K({128, 64, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64},
{K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64},
{K({128, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64},
{K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64},
{K({128, 256, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64},
{K({256, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64},
{K({16, 256, 64}, {16, 64, 64}), CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64}
};
if (c.tile_config_sm80 != tile_map.at(K(tile_shape, warp_shape)))
continue;
// Found matching config
return i;
}
}
}
catch (std::out_of_range const& e)
{
std::cerr << "Warning: error parsing tactic " << tactic_config.dump(2) << std::endl;
}
return -1;
}
template <class BenchClass>
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids, MoeGemmId gemm_id)
{
if (tactic.is_number_integer())
{
@ -193,20 +60,16 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
{
for (auto c : tactic)
{
parseTacticToVectorID<BenchClass>(c, tactic_ids);
parseTacticToVectorID<BenchClass>(c, tactic_ids, gemm_id);
}
}
else if (tactic.is_object())
{
tactic_ids.push_back(parseTacticToId<BenchClass>(tactic));
}
else if (tactic.is_string())
{
assert(tactic.is_string());
auto tactic_name = tactic.get<std::string>();
if (tactic_name == "all")
{
auto all_tactics = listAllTactics<BenchClass>();
auto all_tactics = listAllTactics<BenchClass>(gemm_id);
tactic_ids.resize(all_tactics.size());
std::iota(tactic_ids.begin(), tactic_ids.end(), 0);
}
@ -410,39 +273,15 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
}
// Do this after filtering datatypes as tactics only make sense if we know the data type
bool has_tactic_ids2 = false;
std::vector<int> tactic_ids1{};
std::vector<int> tactic_ids2{};
if (run_config.contains("tactic_id1") || run_config.contains("tactic_id2"))
if (run_config.contains("tactic_id1"))
{
if (run_config.contains("tactic_id"))
{
throw std::invalid_argument("Cannot use tactic_id and tactic_idX");
}
has_tactic_ids2 = true;
parseTacticToVectorID<BenchClass>(run_config["tactic_id1"], tactic_ids1);
parseTacticToVectorID<BenchClass>(run_config["tactic_id2"], tactic_ids2);
parseTacticToVectorID<BenchClass>(run_config["tactic_id1"], tactic_ids1, MoeGemmId::GEMM_1);
}
else
if (run_config.contains("tactic_id2"))
{
parseTacticToVectorID<BenchClass>(run_config["tactic_id"], tactic_ids1);
has_tactic_ids2 = false;
tactic_ids2.resize(1); // Dummy value so we loop exactly once below
}
if (tactic_ids1.empty() || tactic_ids2.empty())
{
std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
static bool printed = false;
if (!printed)
{
printed = true;
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
auto confs = listAllTactics<BenchClass>();
for (auto c : confs)
std::cerr << c.toString();
}
continue;
parseTacticToVectorID<BenchClass>(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2);
}
auto get_or = [&](auto name, auto def)
@ -478,8 +317,6 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
}
else if (gemm_to_profile == (int) GemmToProfile::GEMM_2)
{
if (!has_tactic_ids2)
tactic_ids2 = std::move(tactic_ids1);
tactic_ids1 = {-1};
}
}
@ -494,14 +331,31 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
return val;
};
if (tactic_ids1.empty() || tactic_ids2.empty())
{
std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
static bool printed = false;
if (!printed)
{
printed = true;
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
{
std::cerr << "GEMM " << (int) gemm_id << ":\n";
auto confs = listAllTactics<BenchClass>(gemm_id);
for (auto c : confs)
std::cerr << c.toString();
std::cerr << std::endl;
}
}
continue;
}
for (auto t1 : tactic_ids1)
{
// tactic_ids2 will have one dummy value if has_tactic_ids2 = false
for (auto t2 : tactic_ids2)
{
if (!has_tactic_ids2)
t2 = t1;
benchmark->Args({num_experts, //
get_range("k"), //
get_range("hidden_size"), //
@ -531,7 +385,7 @@ void argGenHardcoded(benchmark::internal::Benchmark* benchmark)
// {ActivationType::Relu, ActivationType::Gelu,
// ActivationType::Silu, ActivationType::Geglu,
// ActivationType::Swiglu};
auto cutlass_tactic = {-1}; // {0,..., listAllTactics<BenchClass>().size()};
auto cutlass_tactic = {-1}; // {0,..., listAllTactics<BenchClass>(MoeGemmId).size()};
auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG}; // {0, 1, 2};
for (auto num_expert : num_experts)
@ -558,14 +412,18 @@ void argGen(benchmark::internal::Benchmark* benchmark)
{
if (LOG_LEVEL >= VERBOSE)
{
std::cout << "List of all tactics for dtype " << (int) BenchClass::toDTypeID() << ":\n";
int i = 0;
for (auto& t : listAllTactics<BenchClass>())
std::cout << "== List of all tactics for dtype " << (int) BenchClass::toDTypeID() << " ==\n";
for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
{
std::cout << "Tactic " << i << ":\n";
std::cout << t.toString() << std::endl;
int i = 0;
std::cout << "=== GEMM " << (int) gemm_id << " ===\n";
for (auto& t : listAllTactics<BenchClass>(gemm_id))
{
std::cout << "==== Tactic " << i << " ====\n";
std::cout << t.toString() << std::endl;
i++;
i++;
}
}
}
@ -652,7 +510,6 @@ void help()
" \"bias\": int, (optional)\n"
" \"do_final_scale\": int, (optional)\n"
" \"act_fn\": int,\n"
" \"tactic_id\": tactic, (see below)\n"
" \"tactic_id1\": tactic, (see below)\n"
" \"tactic_id2\": tactic, (see below)\n"
" \"dtypes\": [string, ...], (optional)\n"
@ -676,27 +533,14 @@ void help()
"- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n"
"- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = "
"swiglu\n"
"- \"tactic_id, tactic_id1, tactic_id2\"\n"
"The config for the CUTLASS GEMM. tactic_id sets the same tactic for both to the same tactic (except in "
"auto mode)\n"
"Use tactic_idX to set the tactic for the corresponding GEMM"
"- \"tactic_id1, tactic_id2\"\n"
"The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM"
"Valid tactics are:\n"
" - An object:\n"
" {\n"
" \"is_tma_warp_specialized\": bool,\n"
" \"tile_shape\": [int, int, int] or int,\n"
" \"cluster_shape\": [int, int, int] or int, (required for sm90, type must be an int if tile_shape "
"is "
"an int)\n"
" \"warp_shape\": [int, int, int], (required for non-sm90 if tile_shape is an array)\n"
" \"stages\": int, (required for non-sm90)\n"
" },\n"
" - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test "
"configurations\n"
" - An array: of integers or objects, forms a list of tactics to sweep\n"
" - An integer: corresponds to an index in the tactics array. WARNING this is not stable between data types "
"or GPU architectures\n"
" - An array: of integers, forms a list of tactics to sweep\n"
" - The string \"all\": This will sweep through all possible tactics\n"
" - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark "
"case. "
" - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark case. "
"Useful for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate "
"results"
"- dtypes - A list of dtypes to run this config through.\n"

View File

@ -294,8 +294,7 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
endif()
if(NOT WIN32)
set_target_properties(${SHARED_TARGET} PROPERTIES LINK_FLAGS
"-Wl,-rpath='$ORIGIN'")
set_target_properties(${SHARED_TARGET} PROPERTIES BUILD_RPATH "$ORIGIN")
endif()
if(BUILD_PYT)

View File

@ -822,6 +822,14 @@ void CacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support non-MLA");
return false;
}
if (selfConfig.getParallelConfig().mContextParallelism != 1
|| destConfig.getParallelConfig().mContextParallelism != 1)
{
TLLM_LOG_WARNING(
"CacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
return false;
}
std::unordered_set<int> setVecDest{
destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()};

View File

@ -20,11 +20,14 @@
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
#include "tensorrt_llm/batch_manager/utils/logitsThread.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/nvtxUtils.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
@ -45,6 +48,8 @@ namespace tensorrt_llm::batch_manager
using SizeType32 = CreateNewDecoderRequests::SizeType32;
using TensorPtr = CreateNewDecoderRequests::TensorPtr;
using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr;
template <typename T>
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
namespace
{
@ -320,149 +325,165 @@ void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeT
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
} // namespace
void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx,
runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput,
CudaStream const& runtimeStream, CudaStream const& decoderStream,
SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxDecodingEngineTokens)
void retrieveDraftLogits(TensorPtr& draftLogitsHost, std::shared_ptr<runtime::ITensor> const& reqDraftLogits,
ModelConfig const& modelConfig, WorldConfig const& worldConfig, bool speculativeDecodingFastLogits,
bool isLeaderInOrchMode, BufferManager const& bufferManager)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (speculativeDecodingMode.predictsDraftTokens())
if (!speculativeDecodingFastLogits)
{
auto const& stream = decoderStream;
BufferManager manager{std::make_shared<CudaStream>(stream.get())};
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
bufferManager.copy(*reqDraftLogits, *draftLogitsHost);
return;
}
auto& dJointOutput = jointDecodingOutput;
if (isLeaderInOrchMode)
{
// reqDraftLogits contains metadata for fast-logits path; validate size.
auto constexpr fastLogitsInfoSize = sizeof(te::SpeculativeDecodingFastLogitsInfo);
TLLM_CHECK_WITH_INFO(reqDraftLogits->getSizeInBytes() >= fastLogitsInfoSize,
"Draft logits metadata buffer is too small to hold SpeculativeDecodingFastLogitsInfo.");
te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo{};
std::memcpy(&fastLogitsInfo, reqDraftLogits->data(), fastLogitsInfoSize);
utils::targetModelReceiveLogits(draftLogitsHost, fastLogitsInfo, modelConfig.getLogitsDtype());
TensorPtr nextDraftTokens
= ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokens, batchIdx, 1);
// FIXME: can we skip this?
manager.setZero(*nextDraftTokens);
if (speculativeDecodingMode.variableDraftLength())
// Broadcast to other ranks if needed
if (worldConfig.isTensorParallel())
{
TensorPtr nextDraftTokensLen
= ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokensLen, batchIdx, 1);
manager.setZero(*nextDraftTokensLen);
auto const& commSession = COMM_SESSION;
auto shape = draftLogitsHost->getShape();
commSession.bcastValue(shape.d[0], 0);
commSession.bcastValue(shape.d[1], 0);
commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
}
}
else
{
TLLM_CHECK_WITH_INFO(worldConfig.isTensorParallel(),
"Fast logits path requires tensor-parallel broadcast for non-leader ranks.");
if (speculativeDecodingMode.isDraftTokensExternal())
{
newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream);
}
else if (speculativeDecodingMode.isMedusa())
{
newRequestMedusa(batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens);
}
else if (speculativeDecodingMode.isLookaheadDecoding())
{
newRequestLookahead(batchIdx, request, jointDecodingInput, jointDecodingOutput, runtimeStream);
}
else if (speculativeDecodingMode.isExplicitDraftTokens())
{
newRequestExplicitDraftTokens(batchIdx, request, jointDecodingOutput, runtimeStream);
}
else if (speculativeDecodingMode.isEagle())
{
newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream);
// Get logits from leader rank
auto const& commSession = COMM_SESSION;
int64_t dims[2];
commSession.bcastValue(dims[0], 0);
commSession.bcastValue(dims[1], 0);
draftLogitsHost->reshape(ITensor::makeShape({dims[0], dims[1]}));
commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
};
void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx,
runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
DecodingInput& jointDecodingInput, CudaStream const& decoderStream)
//! @brief Setups decoder internal tensors for new request in Draft model Sps mode
void newRequestDraftTokensExternal(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest const& llmReq,
SizeType32 numDecodingEngineTokens, runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig,
bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, CudaStream const& decoderStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())};
BufferManager decoderBufferManager{std::make_shared<CudaStream>(decoderStream.get())};
auto& dJointInput = jointDecodingInput;
TLLM_CHECK(jointDecodingInput.externalDraftTokensInputs);
auto& externalDraftTokensInputs = jointDecodingInput.externalDraftTokensInputs;
auto const numDraftTokens = request.generatedTokensPerEngineStep - 1;
auto const& draftTokens = llmReq.getDraftTokens();
auto const numDraftTokens = numDecodingEngineTokens - 1;
auto const useDraftLogits = request.draftLogits.has_value();
if (useDraftLogits)
{
TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value());
TensorPtr draftLogitsReqBatchSlice
= ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1);
draftLogitsReqBatchSlice->squeeze(0);
TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens);
manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice);
}
auto* useDraftLogitsHostPtr = runtime::bufferCast<bool>(*dJointInput.externalDraftTokensInputs->useDraftLogitsHost);
useDraftLogitsHostPtr[batchIdx] = useDraftLogits;
auto useDraftLogitsView = ITensor::slice(dJointInput.externalDraftTokensInputs->useDraftLogits, batchIdx, 1);
runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream);
auto numDraftTokensHostRange = runtime::BufferRange<SizeType32>(*externalDraftTokensInputs->numDraftTokensHost);
numDraftTokensHostRange[batchIdx] = numDraftTokens;
auto numDraftTokensView = ITensor::slice(externalDraftTokensInputs->numDraftTokens, batchIdx, 1);
runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream);
if (numDraftTokens > 0)
{
TensorPtr draftTokensReqBatchSlice
= ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1);
draftTokensReqBatchSlice->squeeze(0);
TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens);
TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens}));
manager.copy(*draftTokensView, *draftTokensReqTokensSlice);
TensorPtr draftTokenIdsHostSlice
= ITensor::slice(externalDraftTokensInputs->draftTokenIdsHost, {batchIdx, 0}, numDraftTokens);
// Copy to pinned host memory (don't care about stream of bufferManager)
decoderBufferManager.copy(draftTokens->data(), *draftTokenIdsHostSlice);
TensorPtr draftTokenIdsSlice
= ITensor::slice(externalDraftTokensInputs->draftTokenIds, {batchIdx, 0}, numDraftTokens);
decoderBufferManager.copy(*draftTokenIdsHostSlice, *draftTokenIdsSlice);
}
auto* numDraftTokensHostPtr
= runtime::bufferCast<SizeType32>(*dJointInput.externalDraftTokensInputs->numDraftTokensHost);
numDraftTokensHostPtr[batchIdx] = numDraftTokens;
auto numDraftTokensView = ITensor::slice(dJointInput.externalDraftTokensInputs->numDraftTokens, batchIdx, 1);
runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream);
auto const& draftLogits = llmReq.getDraftLogits();
auto const useDraftLogits = draftLogits.has_value();
auto useDraftLogitsHostRange = runtime::BufferRange<bool>(*externalDraftTokensInputs->useDraftLogitsHost);
useDraftLogitsHostRange[batchIdx] = useDraftLogits;
auto useDraftLogitsView = ITensor::slice(externalDraftTokensInputs->useDraftLogits, batchIdx, 1);
runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream);
if (useDraftLogits)
{
TensorPtr draftLogitsHostSlice
= ITensor::slice(externalDraftTokensInputs->draftLogitsHost, {batchIdx, 0}, numDraftTokens);
retrieveDraftLogits(draftLogitsHostSlice, draftLogits.value(), modelConfig, worldConfig,
speculativeDecodingFastLogits, isLeaderInOrchMode, decoderBufferManager);
TensorPtr draftLogitsSlice
= ITensor::slice(externalDraftTokensInputs->draftLogits, {batchIdx, 0}, numDraftTokens);
decoderBufferManager.copy(*draftLogitsHostSlice, *draftLogitsSlice);
}
auto const& samplingConfig = llmReq.mSamplingConfig;
bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value();
float const constantThreshold
= useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0];
dJointInput.externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold;
dJointInput.externalDraftTokensInputs->constantThreshold = constantThreshold;
externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold;
externalDraftTokensInputs->constantThreshold = constantThreshold;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens)
//! @brief Setups decoder internal tensors for new Medusa request
void newRequestMedusa(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest& llmReq,
SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, MedusaBuffers const& medusaBuffers,
CudaStream const& decoderStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
llmReq.mSamplingConfig.topKMedusaHeads = {medusaBuffers.mTopKs};
// FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
// When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
auto medusaPaths = ITensor::slice(medusaBuffers.medusaPathsDevice, 0, 1);
auto medusaTreeIds = ITensor::slice(medusaBuffers.medusaTreeIdsDevice, 0, 1);
BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())};
auto& dJointInput = jointDecodingInput;
auto& medusaInputs = jointDecodingInput.medusaInputs;
TensorPtr curTokensPerStepSlice
= ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaCurTokensPerStep), batchIdx, 1);
= ITensor::slice(constPointerCast(medusaInputs->medusaCurTokensPerStep), batchIdx, 1);
// Context phase Medusa processes 1 token only, new value from targetTokensPerStep will be filled at the end
// of first decoder
runtime::kernels::invokeFill(*curTokensPerStepSlice, 1, decoderStream);
TensorPtr targetTokensPerStepSlice
= ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTargetTokensPerStep), batchIdx, 1);
auto const generatedTokensPerEngineStep = request.generatedTokensPerEngineStep;
TLLM_CHECK_WITH_INFO(generatedTokensPerEngineStep <= maxDecodingEngineTokens,
"Tokens per step for (%d) is larger than maximum tokens per step (%d)", generatedTokensPerEngineStep,
= ITensor::slice(constPointerCast(medusaInputs->medusaTargetTokensPerStep), batchIdx, 1);
TLLM_CHECK_WITH_INFO(numDecodingEngineTokens <= maxDecodingEngineTokens,
"Tokens per step for (%d) is larger than maximum tokens per step (%d)", numDecodingEngineTokens,
maxDecodingEngineTokens);
runtime::kernels::invokeFill(*targetTokensPerStepSlice, generatedTokensPerEngineStep, decoderStream);
runtime::kernels::invokeFill(*targetTokensPerStepSlice, numDecodingEngineTokens, decoderStream);
TensorPtr pathsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaPaths), batchIdx, 1);
manager.copy(*request.medusaPaths, *pathsSlice);
TensorPtr pathsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaPaths), batchIdx, 1);
manager.copy(*medusaPaths, *pathsSlice);
TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTreeIds), batchIdx, 1);
manager.copy(*request.medusaTreeIds, *treeIdsSlice);
TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaTreeIds), batchIdx, 1);
manager.copy(*medusaTreeIds, *treeIdsSlice);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
//! @brief Setups decoder internal tensors for new Lookahead request
void newRequestLookahead(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, SizeType32 batchIdx,
CudaStream const& runtimeStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(jointDecodingOutput.lookaheadOutputs);
TLLM_CHECK(jointDecodingInput.lookaheadInputs);
// The first generation step only generate 1 token.
TensorPtr curTokensPerStepSlice
@ -472,65 +493,72 @@ void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime:
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void CreateNewDecoderRequests::newRequestExplicitDraftTokens(SizeType32 batchIdx,
runtime::decoder_batch::Request const& request, DecodingOutput& jointDecodingOutput,
CudaStream const& runtimeStream)
//! @brief Setups decoder internal tensors for new Explicit draft tokens request
void newRequestExplicitDraftTokens(
DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, CudaStream const& runtimeStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(jointDecodingOutput.explicitDraftTokensBuffers);
auto const inputLen = llmReq.getPromptLen();
TensorPtr positionIdsBaseSlice
= ITensor::slice(jointDecodingOutput.explicitDraftTokensBuffers->positionIdsBase, batchIdx, 1);
runtime::kernels::invokeFill(*positionIdsBaseSlice, request.inputLen, runtimeStream);
runtime::kernels::invokeFill(*positionIdsBaseSlice, inputLen, runtimeStream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
//! @brief Setups decoder internal tensors for new Eagle request
void newRequestEagle(DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq,
runtime::ModelConfig const& modelConfig, executor::DecodingConfig const& decodingConfig,
CudaStream const& runtimeStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(jointDecodingOutput.eagleBuffers);
auto& eagleBuffers = *jointDecodingOutput.eagleBuffers;
auto const inputLen = llmReq.getPromptLen();
BufferManager manager{std::make_shared<CudaStream>(runtimeStream.get())};
TensorPtr eagleNetCtxRequestTypesHostSlice
= ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxRequestTypesHost, batchIdx, 1);
TensorPtr eagleNetCtxRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetCtxRequestTypesHost, batchIdx, 1);
TensorPtr eagleNetCtxContextLengthsHostSlice
= ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxContextLengthsHost, batchIdx, 1);
= ITensor::slice(eagleBuffers.eagleNetCtxContextLengthsHost, batchIdx, 1);
TensorPtr eagleNetCtxPastKeyValueLengthsHostSlice
= ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1);
= ITensor::slice(eagleBuffers.eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1);
runtime::bufferCast<SizeType32>(*eagleNetCtxRequestTypesHostSlice)[0] = 0;
runtime::bufferCast<SizeType32>(*eagleNetCtxContextLengthsHostSlice)[0] = request.inputLen;
runtime::bufferCast<SizeType32>(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = request.inputLen;
runtime::bufferCast<SizeType32>(*eagleNetCtxContextLengthsHostSlice)[0] = inputLen;
runtime::bufferCast<SizeType32>(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = inputLen;
TensorPtr eagleNetGenRequestTypesHostSlice
= ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenRequestTypesHost, batchIdx, 1);
TensorPtr eagleNetGenRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetGenRequestTypesHost, batchIdx, 1);
TensorPtr eagleNetGenContextLengthsHostSlice
= ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenContextLengthsHost, batchIdx, 1);
= ITensor::slice(eagleBuffers.eagleNetGenContextLengthsHost, batchIdx, 1);
TensorPtr eagleNetGenPastKeyValueLengthsHostSlice
= ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenPastKeyValueLengthsHost, batchIdx, 1);
= ITensor::slice(eagleBuffers.eagleNetGenPastKeyValueLengthsHost, batchIdx, 1);
runtime::bufferCast<SizeType32>(*eagleNetGenRequestTypesHostSlice)[0] = 1;
runtime::bufferCast<SizeType32>(*eagleNetGenContextLengthsHostSlice)[0] = request.inputLen;
runtime::bufferCast<SizeType32>(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = request.inputLen;
runtime::bufferCast<SizeType32>(*eagleNetGenContextLengthsHostSlice)[0] = inputLen;
runtime::bufferCast<SizeType32>(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = inputLen;
auto const eagleModule = std::dynamic_pointer_cast<tensorrt_llm::runtime::EagleModule const>(
modelConfig.getSpeculativeDecodingModulePtr());
std::optional<executor::EagleChoices> eagleChoicesOpt;
if (request.eagleConfig)
auto const& eagleConfig = llmReq.getEagleConfig() ? llmReq.getEagleConfig() : decodingConfig.getEagleConfig();
if (eagleConfig)
{
eagleChoicesOpt = request.eagleConfig->getEagleChoices();
eagleChoicesOpt = eagleConfig->getEagleChoices();
}
if (!request.eagleConfig || !request.eagleConfig->useDynamicTree())
if (!eagleConfig || !eagleConfig->useDynamicTree())
{
TensorPtr draftPathsHostSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPathsHost, batchIdx, 1);
TensorPtr draftPathsSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPaths, batchIdx, 1);
TensorPtr draftPathsHostSlice = ITensor::slice(eagleBuffers.draftPathsHost, batchIdx, 1);
TensorPtr draftPathsSlice = ITensor::slice(eagleBuffers.draftPaths, batchIdx, 1);
// eagleConfig is nullptr or Eagle-1
std::vector<SizeType32> topKs;
@ -546,6 +574,61 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
//! @brief Setups decoder internal tensors for new speculative decoding request
void newRequestSpeculativeDecoding(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput,
SizeType32 batchIdx, LlmRequest& llmReq, SpeculativeDecodingMode const& speculativeDecodingMode,
SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens,
OptionalRef<MedusaBuffers const> medusaBuffers, runtime::ModelConfig const& modelConfig,
WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, bool speculativeDecodingFastLogits,
bool isLeaderInOrchMode, CudaStream const& runtimeStream, CudaStream const& decoderStream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (speculativeDecodingMode.predictsDraftTokens())
{
BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())};
TLLM_CHECK(jointDecodingOutput.speculativeDecodingOutputs);
auto& speculativeDecodingOutputs = *jointDecodingOutput.speculativeDecodingOutputs;
TensorPtr nextDraftTokens = ITensor::slice(speculativeDecodingOutputs.nextDraftTokens, batchIdx, 1);
// FIXME: can we skip this?
manager.setZero(*nextDraftTokens);
if (speculativeDecodingMode.variableDraftLength())
{
TensorPtr nextDraftTokensLen = ITensor::slice(speculativeDecodingOutputs.nextDraftTokensLen, batchIdx, 1);
manager.setZero(*nextDraftTokensLen);
}
}
if (speculativeDecodingMode.isDraftTokensExternal())
{
newRequestDraftTokensExternal(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, modelConfig,
worldConfig, speculativeDecodingFastLogits, isLeaderInOrchMode, decoderStream);
}
else if (speculativeDecodingMode.isMedusa())
{
TLLM_CHECK(medusaBuffers);
newRequestMedusa(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, maxDecodingEngineTokens,
medusaBuffers.value(), decoderStream);
}
else if (speculativeDecodingMode.isLookaheadDecoding())
{
newRequestLookahead(jointDecodingInput, jointDecodingOutput, batchIdx, runtimeStream);
}
else if (speculativeDecodingMode.isExplicitDraftTokens())
{
newRequestExplicitDraftTokens(jointDecodingOutput, batchIdx, llmReq, runtimeStream);
}
else if (speculativeDecodingMode.isEagle())
{
newRequestEagle(jointDecodingOutput, batchIdx, llmReq, modelConfig, decodingConfig, runtimeStream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
} // namespace
std::tuple<std::vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>>
CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
@ -563,9 +646,6 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
}
inputIds->resize(decoderInputSize);
std::vector<decoder_batch::Request> decoderRequests;
decoderRequests.reserve(finishedContextRequests.size());
std::vector<runtime::ITensor::SharedConstPtr> lookaheadPrompt;
std::vector<executor::LookaheadDecodingConfig> lookaheadAlgoConfigs;
if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
@ -597,36 +677,18 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
auto const promptLen = llmReq->getPromptLen();
auto decoderRequest = decoder_batch::Request{promptLen};
SizeType32 numDecodingEngineTokens{1};
if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
{
if (llmReq->hasDraftTokens())
{
auto const& draftTokens = llmReq->getDraftTokens();
// Copy to pinned host memory (don't care about stream of bufferManager)
decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL);
auto const& draftLogits = llmReq->getDraftLogits();
if (draftLogits.has_value())
{
decoderRequest.draftLogits
= retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager);
}
decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1;
}
else
{
decoderRequest.generatedTokensPerEngineStep = 1;
}
numDecodingEngineTokens = llmReq->getNumDraftTokens() + 1;
}
else if (!modelConfig.getSpeculativeDecodingMode().isNone())
{
decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens();
numDecodingEngineTokens = modelConfig.getMaxDecodingTokens();
}
auto& dJointInput = decoderState.getJointDecodingInput();
auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep;
initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens,
maxSequenceLength, decoderBufferManager);
decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens);
@ -667,16 +729,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
{
TLLM_CHECK(beamWidth == 1);
if (modelConfig.getSpeculativeDecodingMode().isMedusa())
{
TLLM_CHECK(medusaBuffers);
llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs};
// FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
// When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1);
decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1);
}
else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
{
lookaheadPrompt.emplace_back(requestIds);
@ -684,67 +737,17 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
= llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value());
lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig);
}
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
{
decoderRequest.eagleConfig
= llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig();
}
newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig,
decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream,
decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens());
newRequestSpeculativeDecoding(decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(),
batchSlot, *llmReq, decoderState.getSpeculativeDecodingMode(), numDecodingEngineTokens,
decoderState.getMaxDecodingEngineTokens(), medusaBuffers, modelConfig, worldConfig, decodingConfig,
mSpeculativeDecodingFastLogits, mIsLeaderInOrchMode, runtimeStream, decoderStream);
}
decoderRequests.push_back(decoderRequest);
inputOffset += promptLen;
}
return {std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
}
std::shared_ptr<runtime::ITensor> CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig,
WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
BufferManager const& bufferManager) const
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (!mSpeculativeDecodingFastLogits)
{
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return bufferManager.copyFrom(*tensor, MemoryType::kPINNEDPOOL);
}
if (mIsLeaderInOrchMode)
{
te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo;
std::memcpy(&fastLogitsInfo, tensor->data(), sizeof(fastLogitsInfo));
auto logits = utils::targetModelReceiveLogits(fastLogitsInfo, modelConfig).value();
// Broadcast to other ranks if needed
if (worldConfig.isTensorParallel())
{
auto const& commSession = COMM_SESSION;
auto shape = logits->getShape();
commSession.bcastValue(shape.d[0], 0);
commSession.bcastValue(shape.d[1], 0);
commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return logits;
}
// Get logits from leader rank
auto const& commSession = COMM_SESSION;
int64_t dims[2];
commSession.bcastValue(dims[0], 0);
commSession.bcastValue(dims[1], 0);
auto const logitsDtype = modelConfig.getLogitsDtype();
auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(ITensor::makeShape({dims[0], dims[1]}), logitsDtype);
commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return logits;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -558,18 +558,20 @@ void MLACacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support MLA");
return false;
}
if (selfConfig.getAttentionConfig().mKvFactor != destConfig.getAttentionConfig().mKvFactor)
{
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same kv factor");
return false;
}
if (selfConfig.getParallelConfig().mEnableAttentionDP
&& (selfConfig.getParallelConfig().mTensorParallelism % selfConfig.getParallelConfig().mDPsize != 0))
{
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
return false;
}
if (selfConfig.getParallelConfig().mContextParallelism != 1
|| destConfig.getParallelConfig().mContextParallelism != 1)
{
TLLM_LOG_WARNING(
"MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
return false;
}
if (destConfig.getParallelConfig().mEnableAttentionDP
&& (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
{

View File

@ -121,8 +121,8 @@ void draftModelSendLogitsThread(int device, std::atomic<bool>* draftModelThreadS
#endif // ENABLE_MULTI_DEVICE
}
std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig)
void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost,
executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype)
{
#if ENABLE_MULTI_DEVICE
auto const& worldComm = tensorrt_llm::mpi::MpiComm::world();
@ -151,10 +151,7 @@ std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
int64_t dims[2];
MPICHECK(MPI_Mrecv(&dims, count, MPI_INT64_T, &msg, &status));
auto const logitsDtype = modelConfig.getLogitsDtype();
auto tensor = tensorrt_llm::runtime::BufferManager::pinnedPool(
runtime::ITensor::makeShape({dims[0], dims[1]}), logitsDtype);
draftLogitsHost->reshape(runtime::ITensor::makeShape({dims[0], dims[1]}));
worldComm.mprobe(fastLogitsInfo.draftParticipantId, mpi::MpiTag::kSpecDecLogitsData, &msg, &status);
@ -163,11 +160,7 @@ std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
uint64_t const expectedSize = static_cast<uint64_t>(dims[0]) * dims[1] * tc::getDTypeSize(logitsDtype);
TLLM_CHECK((uint64_t) count == expectedSize);
MPICHECK(MPI_Mrecv(tensor->data(), count, MPI_UINT8_T, &msg, &status));
return tensor;
#else
return std::nullopt;
MPICHECK(MPI_Mrecv(draftLogitsHost->data(), count, MPI_UINT8_T, &msg, &status));
#endif // ENABLE_MULTI_DEVICE
}
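The hunk above changes targetModelReceiveLogits to fill a caller-provided pinned host tensor instead of allocating and returning one. A minimal caller sketch, assuming the names used in this diff; the initial shape is just a placeholder that the function reshapes to the received dimensions:
runtime::ITensor::SharedPtr draftLogitsHost = tensorrt_llm::runtime::BufferManager::pinnedPool(
    runtime::ITensor::makeShape({1, 1}), modelConfig.getLogitsDtype());
utils::targetModelReceiveLogits(draftLogitsHost, fastLogitsInfo, modelConfig.getLogitsDtype());
// draftLogitsHost now holds the draft logits with the shape announced by the sender.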

View File

@ -21,10 +21,8 @@
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include <memory>
#include <optional>
namespace tensorrt_llm::batch_manager
{
@ -52,7 +50,7 @@ void draftModelSendLogitsThread(int device, std::atomic<bool>* draftModelThreadS
std::shared_ptr<kv_cache_manager::BaseKVCacheManager> const& crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> const& peftCacheManager);
std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig);
void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost,
executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype);
} // namespace tensorrt_llm::batch_manager::utils

View File

@ -435,6 +435,14 @@ struct CutlassGemmConfig
int sm_version = 80; // Use 80 as a catch all for <90
bool is_tma_warp_specialized = false;
enum class EpilogueFusionType : int
{
NONE,
FINALIZE
};
EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE;
CutlassGemmConfig() = default;
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
@ -505,7 +513,8 @@ struct CutlassGemmConfig
<< "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt()
<< "\n\tcluster shape ID: " << (int) cluster_shape
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false");
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false")
<< "\n\tepilogue fusion type: " << (int) epilogue_fusion_type;
}
else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
{
@ -537,7 +546,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
<< ", cluster_shape_enum: " << int(config.cluster_shape)
<< ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
<< ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false")
<< ", epilogue_fusion_type: " << int(config.epilogue_fusion_type);
}
else
{

View File

@ -531,14 +531,15 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is)
auto tokensPerBlock = su::deserialize<decltype(CacheState::ModelConfig::mTokensPerBlock)>(is);
auto tensorParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mTensorParallelism)>(is);
auto pipelineParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mPipelineParallelism)>(is);
auto contextParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mContextParallelism)>(is);
auto enableAttentionDP = su::deserialize<decltype(CacheState::ParallelConfig::mEnableAttentionDP)>(is);
auto DPrank = su::deserialize<decltype(CacheState::ParallelConfig::mDPrank)>(is);
auto DPsize = su::deserialize<decltype(CacheState::ParallelConfig::mDPsize)>(is);
auto dataType = su::deserialize<decltype(CacheState::mDataType)>(is);
auto attentionType = su::deserialize<decltype(CacheState::AttentionConfig::mAttentionType)>(is);
auto kvFactor = su::deserialize<decltype(CacheState::AttentionConfig::mKvFactor)>(is);
return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, dataType,
attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism,
contextParallelism, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
}
void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
@ -548,6 +549,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& o
su::serialize(state.mModelConfig.mTokensPerBlock, os);
su::serialize(state.mParallelConfig.mTensorParallelism, os);
su::serialize(state.mParallelConfig.mPipelineParallelism, os);
su::serialize(state.mParallelConfig.mContextParallelism, os);
su::serialize(state.mParallelConfig.mEnableAttentionDP, os);
su::serialize(state.mParallelConfig.mDPrank, os);
su::serialize(state.mParallelConfig.mDPsize, os);
@ -564,6 +566,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state)
totalSize += su::serializedSize(state.mModelConfig.mTokensPerBlock);
totalSize += su::serializedSize(state.mParallelConfig.mTensorParallelism);
totalSize += su::serializedSize(state.mParallelConfig.mPipelineParallelism);
totalSize += su::serializedSize(state.mParallelConfig.mContextParallelism);
totalSize += su::serializedSize(state.mParallelConfig.mEnableAttentionDP);
totalSize += su::serializedSize(state.mParallelConfig.mDPrank);
totalSize += su::serializedSize(state.mParallelConfig.mDPsize);
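With mContextParallelism added to serialize, deserializeCacheState, and serializedSize alike, the field now survives a (de)serialization round trip. A hedged sketch, assuming an existing kv_cache::CacheState named state and <sstream> available:
std::stringstream ss;
Serialization::serialize(state, ss);
auto const restored = Serialization::deserializeCacheState(ss);
// restored carries the same context parallelism (and the other parallel-config fields) as state.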

View File

@ -0,0 +1,268 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "moeTopKFuncs.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/archCondition.h"
#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
#include <climits> // For INT_MAX
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda/std/limits> // For numeric_limits
#include <math.h>
namespace cg = cooperative_groups;
using namespace tensorrt_llm::common;
namespace tensorrt_llm::kernels
{
static constexpr int BLOCK_SIZE = 1024;
static constexpr int WARP_SIZE = 32;
static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts)
{
T maxScore = T{-INFINITY};
if (laneIdx < NumTopExperts)
{
maxScore = score >= maxScore ? score : maxScore;
}
maxScore = cg::reduce(warp, maxScore, cg::greater<T>());
float sumScore{0.f};
float newScore;
// Get the summation of scores for each token
if (laneIdx < NumTopExperts)
{
newScore = static_cast<float>(score) - static_cast<float>(maxScore);
newScore = static_cast<float>(exp(newScore));
sumScore += newScore;
}
sumScore = cg::reduce(warp, sumScore, cg::plus<float>());
if (laneIdx < NumTopExperts)
{
score = static_cast<T>(newScore / sumScore);
}
return score;
}
template <typename DataType, int VecSize>
__device__ void calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, DataType (&scores)[VecSize])
{
DataType maxScore = DataType{-INFINITY};
DataType sumScore = DataType{0.f};
// Get the max score for each token
#pragma unroll
for (int i = 0; i < VecSize; ++i)
{
maxScore = scores[i] >= maxScore ? scores[i] : maxScore;
}
maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());
// Get the summation of scores for each token
#pragma unroll
for (int i = 0; i < VecSize; ++i)
{
scores[i] = static_cast<DataType>(exp(scores[i] - maxScore));
sumScore += scores[i];
}
sumScore = cg::reduce(warp, sumScore, cg::plus<DataType>());
// Normalize the scores
#pragma unroll
for (int i = 0; i < VecSize; ++i)
{
scores[i] = static_cast<DataType>(scores[i] / sumScore);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts,
bool DoSoftmaxBeforeTopK>
__global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices,
int32_t const numTokens, int32_t const numExperts, int32_t const topK)
{
using BaseType = std::conditional_t<DoSoftmaxBeforeTopK, float, InputT>;
uint32_t const blockRank = blockIdx.x;
uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x;
uint32_t const warpIdx = tIdx / WARP_SIZE;
uint32_t const laneIdx = tIdx % WARP_SIZE;
uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
BaseType minScore = BaseType{-INFINITY};
for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
{
auto scoreOffset = tokenId * numExperts;
auto outputOffset = tokenId * topK;
BaseType inputScore[MaxNumExperts / WARP_SIZE];
IdxT inputIndex[MaxNumExperts / WARP_SIZE];
BaseType warpTopKScore[MaxNumTopExperts];
IdxT warpTopKExpertIdx[MaxNumTopExperts];
// Load scores and indices for this warp
for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i)
{
auto expertIdx = i * WARP_SIZE + laneIdx;
inputScore[i]
= expertIdx < numExperts ? static_cast<BaseType>(routerLogits[scoreOffset + expertIdx]) : minScore;
inputIndex[i] = expertIdx;
}
if constexpr (DoSoftmaxBeforeTopK)
{
calcSoftmax(warp, inputScore);
}
// Reduce topK scores and indices for this warp
reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);
// Normalize the scores
if constexpr (DoSoftmaxBeforeTopK)
{
if (laneIdx < topK)
{
topkValues[outputOffset + laneIdx] = static_cast<OutputT>(warpTopKScore[laneIdx]);
topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
}
}
else
{
auto softmaxScore = calcSoftmax(warp,
laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx,
topK);
if (laneIdx < topK)
{
topkValues[outputOffset + laneIdx] = static_cast<OutputT>(softmaxScore);
topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
}
}
} // end for tokenId
}
int nextPowerOfTwo(int num)
{
if (num <= 0)
{
return 1; // Handle invalid input
}
int power = 1;
while (power < num)
{
// Check for overflow before shifting
if (power > INT_MAX / 2)
{
return power;
}
power <<= 1;
}
return power;
}
#define CASE(MAX_NUM_EXPERTS) \
case MAX_NUM_EXPERTS: \
switch (maxNumTopExperts) \
{ \
case 1: \
kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1, DoSoftmaxBeforeTopK>; \
break; \
case 2: \
kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2, DoSoftmaxBeforeTopK>; \
break; \
case 4: \
kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4, DoSoftmaxBeforeTopK>; \
break; \
case 8: \
kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8, DoSoftmaxBeforeTopK>; \
break; \
default: kernelInstance = nullptr; break; \
} \
break;
template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
{
const uint32_t maxNumBlocks = 1024;
const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks);
uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts);
uint32_t maxNumTopExperts = nextPowerOfTwo(topK);
auto* kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8, DoSoftmaxBeforeTopK>;
switch (maxNumExperts)
{
CASE(32)
CASE(64)
CASE(96)
CASE(128)
default: kernelInstance = nullptr; break;
}
TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Cannot find a corresponding kernel instance.");
dim3 renormMoeRoutingGridDim(numBlocks);
dim3 renormMoeRoutingBlockDim(BLOCK_SIZE);
cudaLaunchConfig_t config;
config.gridDim = renormMoeRoutingGridDim;
config.blockDim = renormMoeRoutingBlockDim;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens),
static_cast<int32_t>(numExperts), static_cast<int32_t>(topK));
sync_check_cuda_error(stream);
}
#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK) \
template void invokeRenormMoeRouting<InputT, OutputT, IdxT, DoSoftmaxBeforeTopK>(InputT * routerLogits, \
OutputT * topkValues, IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, \
int64_t const topK, cudaStream_t const stream);
INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false);
#ifdef ENABLE_BF16
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false);
#endif
INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true);
#ifdef ENABLE_BF16
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true);
#endif
} // namespace tensorrt_llm::kernels
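As a usage sketch (buffer names and sizes are illustrative, not from this change), the host-side entry point instantiated above can be invoked as follows. DoSoftmaxBeforeTopK = true selects softmax-then-top-k; false reproduces the original renormalization order (top-k first, then softmax over the selected scores):
int64_t const numTokens = 16, numExperts = 128, topK = 8;
float* routerLogits{};
float* topkValues{};
int32_t* topkIndices{};
cudaMalloc(&routerLogits, numTokens * numExperts * sizeof(float));  // [numTokens, numExperts]
cudaMalloc(&topkValues, numTokens * topK * sizeof(float));          // [numTokens, topK]
cudaMalloc(&topkIndices, numTokens * topK * sizeof(int32_t));       // [numTokens, topK]
cudaStream_t stream{};
cudaStreamCreate(&stream);
tensorrt_llm::kernels::invokeRenormMoeRouting<float, float, int32_t, /*DoSoftmaxBeforeTopK=*/true>(
    routerLogits, topkValues, topkIndices, numTokens, numExperts, topK, stream);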

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -23,7 +23,7 @@
namespace tensorrt_llm::kernels
{
template <typename InputT, typename OutputT, typename IdxT>
template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
int64_t const numExperts, int64_t const topK, cudaStream_t const stream);
} // namespace tensorrt_llm::kernels

View File

@ -288,15 +288,20 @@ public:
void moeGemm(GroupedGemmInput<T, WeightType, ScaleBiasType, OutputType> inputs,
TmaWarpSpecializedGroupedGemmInput hopper_inputs);
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
static std::vector<cutlass_extensions::CutlassGemmConfig> getTmaWarpSpecializedConfigs(int sm);
static std::vector<cutlass_extensions::CutlassGemmConfig> getBlackwellConfigs(int sm);
static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(bool supports_finalize_fusion) const;
static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm, bool supports_finalize_fusion);
static std::vector<cutlass_extensions::CutlassGemmConfig> getTmaWarpSpecializedConfigs(
int sm, bool supports_finalize_fusion);
static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);
[[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const;
[[nodiscard]] bool supportsTmaWarpSpecialized() const;
[[nodiscard]] bool supportsTmaWarpSpecialized() const
{
return supportsTmaWarpSpecialized(sm_);
}
[[nodiscard]] static bool supportsTmaWarpSpecialized(int sm);
[[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config,
ActivationType activation_type, int gemm_n, int gemm_k) const;
[[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const;

View File

@ -228,6 +228,13 @@ struct MOEParallelismConfig
}
};
enum class MoeGemmId : int
{
Undefined = 0,
GEMM_1,
GEMM_2
};
struct QuantParams
{
// Int weight only quantization params
@ -446,7 +453,7 @@ public:
virtual void setTactic(std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config,
std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config)
= 0;
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() = 0;
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(MoeGemmId gemm_id) = 0;
virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
@ -593,15 +600,15 @@ public:
gemm2_config_ = std::move(gemm2_config);
}
std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() override
std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(MoeGemmId gemm_id) override
{
return moe_gemm_runner_.getConfigs();
return moe_gemm_runner_.getConfigs(gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused());
}
static std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(int sm)
static std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(int sm, MoeGemmId gemm_id)
{
using RunnerType = decltype(moe_gemm_runner_);
return RunnerType::getConfigs(sm);
return RunnerType::getConfigs(sm, gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm));
}
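Tactic queries are now keyed by GEMM id, and only GEMM2 can be offered FINALIZE-fused configs. A hedged illustration of the call sites this implies (the runner instance name is assumed):
auto const gemm1_tactics = moe_runner.getTactics(MoeGemmId::GEMM_1); // plain configs only
auto const gemm2_tactics = moe_runner.getTactics(MoeGemmId::GEMM_2); // may also include EpilogueFusionType::FINALIZE variants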
void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
@ -798,6 +805,12 @@ private:
&& !use_w4_groupwise;
}
static bool mayHaveFinalizeFused(int sm)
{
using RunnerType = decltype(moe_gemm_runner_);
return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise;
}
// TODO: This should eventually take the quant params to give more flexibility
static auto getScalingType()
{
@ -895,12 +908,7 @@ struct GemmProfilerBackend
{
public:
using Config = cutlass_extensions::CutlassGemmConfig;
enum class GemmToProfile
{
Undefined = 0,
GEMM_1,
GEMM_2
};
using GemmToProfile = MoeGemmId;
void init(CutlassMoeFCRunnerInterface& runner, GemmToProfile gemm_to_profile, nvinfer1::DataType dtype,
nvinfer1::DataType wtype, nvinfer1::DataType otype, int num_experts, int k, int64_t hidden_size,
@ -951,7 +959,6 @@ public:
CutlassMoeFCRunnerInterface* mInterface;
GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
std::vector<Config> mAllTacticsSaved;
int mSM{};
int64_t mNumExperts{};
int64_t mNumExpertsPerNode{};
@ -972,7 +979,7 @@ public:
// This will be a unique value for every iteration of warmup and actual bench
constexpr static int64_t NUM_ROUTING_SAMPLES = 16;
std::array<TmaWarpSpecializedGroupedGemmInput, NUM_ROUTING_SAMPLES> mTmaInputCache;
std::array<std::array<TmaWarpSpecializedGroupedGemmInput, 2>, NUM_ROUTING_SAMPLES> mTmaInputCache;
QuantParams mQuantParams;
bool mBias{};
@ -985,7 +992,8 @@ public:
private:
void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream);
void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream);
void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream);
void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream);
};
// Populates a buffer with random values for use with MOE benchmarking

View File

@ -57,7 +57,6 @@ namespace kernels
{
namespace cutlass_kernels
{
template <typename T, typename arch, typename ThreadblockShape, typename WarpShape, int Stages>
void genericInt8GemmKernelLauncher(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace,

View File

@ -475,17 +475,18 @@ void dispatchMoeGemmToCutlass(GroupedGemmInput<T, WeightType, GemmOutputType, Ge
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs() const
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
bool supports_finalize_fusion) const
{
return getConfigs(sm_);
return getConfigs(sm_, supports_finalize_fusion);
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
int sm)
int sm, bool supports_finalize_fusion)
{
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs = getTmaWarpSpecializedConfigs(sm);
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs
= getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion);
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs = getAmpereConfigs(sm);
std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs));
return candidate_configs;
@ -521,7 +522,8 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
std::vector<cutlass_extensions::CutlassGemmConfig>
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedConfigs(int sm)
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedConfigs(
int sm, bool supports_finalize_fusion)
{
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
static constexpr auto weight_only_flag
@ -568,6 +570,17 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedCo
= tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(100, max_split_k, config_type_param);
std::copy(sm100_configs.begin(), sm100_configs.end(), std::back_inserter(tma_ws_configs));
}
if (supports_finalize_fusion)
{
// Duplicate the configs and set the epilogue fusion type to FINALIZE
auto finalize_configs = tma_ws_configs;
std::transform(finalize_configs.begin(), finalize_configs.end(), std::back_inserter(tma_ws_configs),
[](auto& config)
{
config.epilogue_fusion_type = cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
return config;
});
}
return tma_ws_configs;
}
@ -580,13 +593,11 @@ bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isTmaWarpSpecializ
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsTmaWarpSpecialized() const
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsTmaWarpSpecialized(int sm)
{
return (sm_ == 90 && tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|| (sm_ >= 100 && sm_ < 120
&& tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType>())
|| ((sm_ == 120 || sm_ == 121)
&& tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation<T, WeightType>());
return (sm == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|| (sm >= 100 && sm < 120 && kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType>())
|| ((sm == 120 || sm == 121) && kernels::cutlass_kernels::isValidSM120MOESpecialisation<T, WeightType>());
}
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
@ -833,7 +844,9 @@ size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspace
if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType>() && !use_w4afp8
&& !use_wfp4a16)
{
auto configs = getTmaWarpSpecializedConfigs(sm_);
// Finalize fusion may not actually be supported by the kernel;
// if a config is not supported, we will catch the error and skip it.
auto configs = getTmaWarpSpecializedConfigs(sm_, true);
auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
if constexpr (use_wfp4afp8)
{

View File

@ -2847,9 +2847,10 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
expert_first_token_offset_ = getWsPtr(int64_t{}, "expert_first_token_offset");
// We check if the provided config uses fused finalize and disable it if it does not
bool const gemm2_using_tma_ws = moe_gemm_runner_.isTmaWarpSpecialized(*gemm2_config_);
bool gemm2_using_finalize_fusion
= gemm2_config_->epilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
permuted_token_final_scales_
= (gemm2_using_tma_ws && mayHaveFinalizeFused()) ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr;
= gemm2_using_finalize_fusion ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr;
bool const is_gated_activation = isGatedActivation(activation_type);
bool const gemm1_using_fused_moe
@ -4006,8 +4007,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
bool apply_bias = parallelism_config.tp_rank == 0;
auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr;
bool gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
bool using_fused_finalize
= use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora;
= use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora;
TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_using_finalize_fusion,
"GEMM2 tactic requests finalize fusion, but the runner is not configured to use it");
if (using_fused_finalize)
{
assert(min_latency_mode == false);
@ -4550,14 +4555,26 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr
}
}
void GemmProfilerBackend::prepareTmaWsInputs(
int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr_char, void const* expert_weights,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream)
{
if (mSM < 90)
{
return;
}
bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4);
bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16)
&& mWType == nvinfer1::DataType::kUINT8);
bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise
|| mGemmToProfile != GemmToProfile::GEMM_2;
if (use_finalize_fusion && finalize_fusion_not_supported)
{
return;
}
auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);
#define GET_WS_PTR(type, name) \
@ -4596,28 +4613,24 @@ void GemmProfilerBackend::prepareTmaWsInputs(
size_t num_expanded_tokens = num_tokens * mK;
for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++)
{
mTmaInputCache[i].configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace,
// Note: Even though we have separate TMA WS inputs for finalize fusion on/off, we reuse the same
// pointers to save space.
auto& cache_element = mTmaInputCache[i][use_finalize_fusion];
cache_element.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace,
workspaces.at("gemm_workspace").first, mScalingType);
tma_ws_input_workspace += tma_ws_size;
int64_t* expert_first_token_offset = expert_first_token_offset_base + i * (mNumExpertsPerNode + 1);
int* permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_base + i * num_expanded_tokens;
auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? mTmaInputCache[i] : dummy_tma_ws_input;
auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? mTmaInputCache[i] : dummy_tma_ws_input;
auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? cache_element : dummy_tma_ws_input;
auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? cache_element : dummy_tma_ws_input;
if (mSM >= 90)
{
/* GEMM1 */
gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4);
bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16)
&& mWType == nvinfer1::DataType::kUINT8);
bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
bool using_fused_finalize
= mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise;
if (using_fused_finalize)
if (use_finalize_fusion)
{
assert(!mMinLatencyMode);
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
@ -4652,7 +4665,6 @@ void GemmProfilerBackend::prepareTmaWsInputs(
void GemmProfilerBackend::prepare(
int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
{
mAllTacticsSaved = mInterface->getTactics();
mSampleIndex = 0;
auto workspace_size = getWorkspaceSize(num_tokens);
@ -4660,7 +4672,10 @@ void GemmProfilerBackend::prepare(
prepareRouting(num_tokens, workspace_ptr_char, stream);
prepareQuantParams(num_tokens, workspace_ptr_char, stream);
prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, stream);
prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, stream);
prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights,
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, stream);
}
size_t GemmProfilerBackend::getWorkspaceSize(int maxM)
@ -4724,7 +4739,9 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
TmaWarpSpecializedGroupedGemmInput tma_ws_input_template;
if (tactic.is_tma_warp_specialized)
{
tma_ws_input_template = mTmaInputCache[mSampleIndex];
tma_ws_input_template = mTmaInputCache[mSampleIndex][tactic.epilogue_fusion_type
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE];
TLLM_CHECK_WITH_INFO(tma_ws_input_template.isValid(), "TMA WS input template is not initialized");
}
mInterface->is_profiler = true;

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf
size 67051604
oid sha256:d6a3f6adef11003f794a6cec1235d0c622ead71b4e801a89866e91dfd91bb30c
size 67053244

View File

@ -1,2 +1,2 @@
568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 9c0a42825905952beaf9b35d5a35d58de1a123fa
317a25037093a6f3d156ffa58a68bce53071ef68dacdcb04cc0aaeea80b64e76 libtensorrt_llm_internal_cutlass_kernels_static.a
commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d
size 66872936
oid sha256:489fb557b78062efedd1514f2995fafb9216bb0e0068a550e86763efb9d5eee9
size 66874608

View File

@ -1,2 +1,2 @@
813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a
commit 9c0a42825905952beaf9b35d5a35d58de1a123fa
5a31acd0fb1415196bff71fa4a8d1dded147e15ea10821cc46c85684c66986ee libtensorrt_llm_internal_cutlass_kernels_static.a
commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b

View File

@ -0,0 +1,205 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifndef TRTLLM_MOETOPKFUNCS_CUH_H
#define TRTLLM_MOETOPKFUNCS_CUH_H
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include "tensorrt_llm/kernels/archCondition.h"
namespace tensorrt_llm::kernels
{
namespace reduce_topk
{
namespace cg = cooperative_groups;
static constexpr int kWARP_SIZE = 32;
static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
template <typename T_>
struct TopKRedType
{
using T = T_;
static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>
|| std::is_same_v<T, int>,
"Top K reduction only implemented for int, float, float16 and bfloat16");
using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;
static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
static constexpr int kMaxIdx = 65535;
TypeCmp compValIdx;
static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
{
auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
TypeCmp compactTmp = reinterpret_cast<TypeCmp&>(valueBits);
compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
// Use 65535 minus idx to give higher priority to elements with smaller indices.
return compactTmp;
}
static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
{
// Since "65535 - idx" is always smaller than 65536 and positive, we can directly use it as the lower 16 bits
index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));
auto compactTmp = cmp >> kMoveBits;
auto valueBits
= cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
value = reinterpret_cast<T&>(valueBits);
}
__host__ __device__ TopKRedType() = default;
__host__ __device__ TopKRedType(T val, int32_t idx)
: compValIdx(makeCmpVal(val, idx))
{
}
__host__ __device__ operator TypeCmp() const noexcept
{
return compValIdx;
}
__device__ inline TypeCmp reduce(cg::thread_block_tile<kWARP_SIZE> const& warp)
{
if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8)
{
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
}
else
{
TypeCmp result;
asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
return result;
}
}
};
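The packed key is what lets a single warp-wide max-reduction return both a value and a deterministic index. For a 32-bit score type, the layout produced by makeCmpVal above is (illustrative restatement, not new behavior):
// compValIdx = (TwiddleIn(value) << 32) | (0xFFFF & (65535 - idx))
// Larger scores dominate the reduction; on equal scores the smaller index wins,
// because 65535 - idx is larger when idx is smaller.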
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int K_, bool Enable_>
struct TopKIdx
{
// by default, empty
};
template <int K_>
struct TopKIdx<K_, true>
{
static constexpr int K = K_;
int32_t val[K];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#define TOPK_SWAP(I, J) \
{ \
auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \
auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \
topK[I].compValIdx = pairMax; \
topK[J].compValIdx = pairMin; \
}
template <int N, typename RedType>
struct Sort;
template <typename RedType>
struct Sort<1, RedType>
{
static __device__ void run(RedType* topK) {}
};
template <typename RedType>
struct Sort<2, RedType>
{
static __device__ void run(RedType* topK)
{
TOPK_SWAP(0, 1);
}
};
template <typename RedType>
struct Sort<3, RedType>
{
static __device__ void run(RedType* topK)
{
TOPK_SWAP(0, 1);
TOPK_SWAP(1, 2);
TOPK_SWAP(0, 1);
}
};
template <typename RedType>
struct Sort<4, RedType>
{
static __device__ void run(RedType* topK)
{
TOPK_SWAP(0, 2);
TOPK_SWAP(1, 3);
TOPK_SWAP(0, 1);
TOPK_SWAP(2, 3);
TOPK_SWAP(1, 2);
}
};
template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
Type (&value)[N], int32_t (&idx)[N], Type minValue)
{
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
static_assert(N > 0, "Top K must have N > 0");
static_assert(N < 5, "Only support candidates number less than or equal to 128");
using RedType = TopKRedType<Type>;
RedType topK[N];
#pragma unroll
for (int nn = 0; nn < N; ++nn)
{
topK[nn] = RedType{value[nn], idx[nn]};
}
if constexpr (!IsSorted)
{
Sort<N, RedType>::run(topK);
}
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < K; ++kk)
{
bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
for (int nn = 0; nn < N; ++nn)
{
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
}
// get the next largest value
packedMax = topK[0].reduce(warp);
RedType::unpack(out[kk], outIdx[kk], packedMax);
}
};
#undef TOPK_SWAP
} // namespace reduce_topk
} // namespace tensorrt_llm::kernels
#endif // TRTLLM_MOETOPKFUNCS_CUH_H

View File

@ -1,376 +0,0 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/archCondition.h"
#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h"
#include <climits> // For INT_MAX
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda/std/limits> // For numeric_limits
#include <math.h>
namespace cg = cooperative_groups;
using namespace tensorrt_llm::common;
namespace tensorrt_llm::kernels
{
static constexpr int BLOCK_SIZE = 1024;
static constexpr int WARP_SIZE = 32;
static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
namespace reduce_topk
{
static constexpr bool TLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
template <typename T_>
struct TopKRedType
{
using T = T_;
static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>,
"Top K reduction only implemented for float, float16 and bfloat16");
using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;
static constexpr int moveBits = (sizeof(T) == 4) ? 32 : 16;
static constexpr int maxIdx = 65535;
TypeCmp compValIdx;
static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
{
auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
TypeCmp compactTmp = reinterpret_cast<TypeCmp&>(valueBits);
compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx));
// Use 65535 minus idx to give higher priority to elements with smaller indices.
return compactTmp;
}
static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
{
// Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the lower 16 bits
index = maxIdx - static_cast<int32_t>((cmp & 0xFFFF));
auto compactTmp = cmp >> moveBits;
auto valueBits
= cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
value = reinterpret_cast<T&>(valueBits);
}
__host__ __device__ TopKRedType() = default;
__host__ __device__ TopKRedType(T val, int32_t idx)
: compValIdx(makeCmpVal(val, idx))
{
}
__host__ __device__ operator TypeCmp() const noexcept
{
return compValIdx;
}
__device__ inline TypeCmp reduce(cg::thread_block_tile<WARP_SIZE> const& warp)
{
if constexpr (!TLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8)
{
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
}
else
{
TypeCmp result;
asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
return result;
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int K_, bool Enable_>
struct TopKIdx
{
// by default, empty
};
template <int K_>
struct TopKIdx<K_, true>
{
static constexpr int K = K_;
int32_t val[K];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
#define TOPK_SWAP(I, J) \
{ \
auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \
auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \
topK[I].compValIdx = pairMax; \
topK[J].compValIdx = pairMin; \
}
template <int N, typename RedType>
struct Sort;
template <typename RedType>
struct Sort<1, RedType>
{
static __device__ void run(RedType* topK) {}
};
template <typename RedType>
struct Sort<2, RedType>
{
static __device__ void run(RedType* topK)
{
TOPK_SWAP(0, 1);
}
};
template <typename RedType>
struct Sort<3, RedType>
{
static __device__ void run(RedType* topK)
{
TOPK_SWAP(0, 1);
TOPK_SWAP(1, 2);
TOPK_SWAP(0, 1);
}
};
template <typename RedType>
struct Sort<4, RedType>
{
static __device__ void run(RedType* topK)
{
TOPK_SWAP(0, 2);
TOPK_SWAP(1, 3);
TOPK_SWAP(0, 1);
TOPK_SWAP(2, 3);
TOPK_SWAP(1, 2);
}
};
template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopK(cg::thread_block_tile<WARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
Type (&value)[N], int32_t (&idx)[N], Type minValue)
{
static_assert(K > 0, "Top K must have K > 0");
static_assert(K < WARP_SIZE, "Top K must have K < WARP_SIZE");
static_assert(N > 0, "Top K must have N > 0");
static_assert(N < 5, "Only support candidates number less than or equal to 128");
using RedType = TopKRedType<Type>;
RedType topK[N];
#pragma unroll
for (int nn = 0; nn < N; ++nn)
{
topK[nn] = RedType{value[nn], idx[nn]};
}
if constexpr (!IsSorted)
{
Sort<N, RedType>::run(topK);
}
typename RedType::TypeCmp packedMax{};
#pragma unroll
for (int kk = 0; kk < K; ++kk)
{
bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
for (int nn = 0; nn < N; ++nn)
{
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
}
// get the next largest value
packedMax = topK[0].reduce(warp);
RedType::unpack(out[kk], outIdx[kk], packedMax);
}
};
#undef TOPK_SWAP
} // end of namespace reduce_topk
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts)
{
T maxScore = T{-INFINITY};
if (laneIdx < NumTopExperts)
{
maxScore = score >= maxScore ? score : maxScore;
}
maxScore = cg::reduce(warp, maxScore, cg::greater<T>());
float sumScore = float{0.f};
float newScore;
// Get the summation of scores for each token
if (laneIdx < NumTopExperts)
{
newScore = static_cast<float>(score) - static_cast<float>(maxScore);
newScore = static_cast<float>(exp(newScore));
sumScore += newScore;
}
sumScore = cg::reduce(warp, sumScore, cg::plus<float>());
if (laneIdx < NumTopExperts)
{
score = static_cast<T>(newScore / sumScore);
}
return score;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts>
__global__ void renormMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices,
int32_t const numTokens, int32_t const numExperts, int32_t const topK)
{
uint32_t const blockRank = blockIdx.x;
uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x;
uint32_t const warpIdx = tIdx / WARP_SIZE;
uint32_t const laneIdx = tIdx % WARP_SIZE;
uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<WARP_SIZE>(block);
InputT minScore = InputT{-INFINITY};
for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
{
auto scoreOffset = tokenId * numExperts;
auto outputOffset = tokenId * topK;
InputT inputScore[MaxNumExperts / WARP_SIZE];
IdxT inputIndex[MaxNumExperts / WARP_SIZE];
InputT warpTopKScore[MaxNumTopExperts];
IdxT warpTopKExpertIdx[MaxNumTopExperts];
// Load scores and indices for this warp
for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i)
{
auto expertIdx = i * WARP_SIZE + laneIdx;
inputScore[i]
= expertIdx < numExperts ? static_cast<InputT>(routerLogits[scoreOffset + expertIdx]) : minScore;
inputIndex[i] = expertIdx;
}
// Reduce topK scores and indices for this warp
reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);
// Perform softmax on topK scores
auto score = calcSoftmax(warp,
laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx, topK);
if (laneIdx < topK)
{
topkValues[outputOffset + laneIdx] = static_cast<OutputT>(score);
topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
}
} // end for tokenId
}
int nextPowerOfTwo(int num)
{
if (num <= 0)
{
return 1; // Handle invalid input
}
int power = 1;
while (power < num)
{
// Check for overflow before shifting
if (power > INT_MAX / 2)
{
return power;
}
power <<= 1;
}
return power;
}
#define CASE(MAX_NUM_EXPERTS) \
case MAX_NUM_EXPERTS: \
switch (maxNumTopExperts) \
{ \
case 1: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1>; break; \
case 2: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2>; break; \
case 4: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4>; break; \
case 8: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8>; break; \
default: kernelInstance = nullptr; break; \
} \
break;
template <typename InputT, typename OutputT, typename IdxT>
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
{
const uint32_t maxNumBlocks = 1024;
const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks);
uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts);
uint32_t maxNumTopExperts = nextPowerOfTwo(topK);
auto* kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8>;
switch (maxNumExperts)
{
CASE(32)
CASE(64)
CASE(96)
CASE(128)
default: kernelInstance = nullptr; break;
}
if (kernelInstance == nullptr)
{
TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance.");
}
dim3 renormMoeRoutingGridDim(numBlocks);
dim3 renormMoeRoutingBlockDim(BLOCK_SIZE);
cudaLaunchConfig_t config;
config.gridDim = renormMoeRoutingGridDim;
config.blockDim = renormMoeRoutingBlockDim;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens),
static_cast<int32_t>(numExperts), static_cast<int32_t>(topK));
sync_check_cuda_error(stream);
}
#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT) \
template void invokeRenormMoeRouting<InputT, OutputT, IdxT>(InputT * routerLogits, OutputT * topkValues, \
IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, int64_t const topK, \
cudaStream_t const stream);
INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t);
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t);
#ifdef ENABLE_BF16
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t);
#endif
} // namespace tensorrt_llm::kernels

View File

@ -22,11 +22,17 @@
*/
#include <cuda_runtime_api.h>
#include "moeTopKFuncs.cuh"
#include "topkLastDim.h"
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cub/cub.cuh>
#include <cuda/atomic>
#include <cuda/std/limits>
#include <limits>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <type_traits>
namespace tensorrt_llm
{
@ -203,12 +209,12 @@ __host__ __device__ IdxT calc_buf_len(IdxT len)
* @param len the number of elements to read
* @param f the lambda taking two arguments (T x, IdxT idx)
*/
template <typename T, typename idxT, typename Func>
__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, idxT len, Func f)
template <typename T, typename IdxT, typename Func>
__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, IdxT len, Func f)
{
if constexpr (sizeof(T) >= sizeof(WideT))
{
for (idxT i = thread_rank; i < len; i += num_threads)
for (IdxT i = thread_rank; i < len; i += num_threads)
{
f(in[i], i);
}
@ -233,12 +239,12 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con
skip_cnt = len;
}
WideT const* in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
const idxT len_cast = (len - skip_cnt) / items_per_scalar;
const IdxT len_cast = (len - skip_cnt) / items_per_scalar;
for (idxT i = thread_rank; i < len_cast; i += num_threads)
for (IdxT i = thread_rank; i < len_cast; i += num_threads)
{
wide.scalar = in_cast[i];
const idxT real_i = skip_cnt + i * items_per_scalar;
const IdxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
for (int j = 0; j < items_per_scalar; ++j)
{
@ -258,7 +264,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con
// and so
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WARP_SIZE
// no need to use loop
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
if (remain_i < len)
{
f(in[remain_i], remain_i);
@ -267,14 +273,14 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con
}
// sync_width should be >= WARP_SIZE
template <typename T, typename idxT, typename Func>
__device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width)
template <typename T, typename IdxT, typename Func>
__device__ void vectorized_process(T const* in, IdxT len, Func f, int sync_width)
{
const idxT stride = blockDim.x * gridDim.x;
const idxT tid = blockIdx.x * blockDim.x + threadIdx.x;
const IdxT stride = blockDim.x * gridDim.x;
const IdxT tid = blockIdx.x * blockDim.x + threadIdx.x;
if constexpr (sizeof(T) >= sizeof(WideT))
{
for (idxT i = tid; i < len; i += stride)
for (IdxT i = tid; i < len; i += stride)
{
f(in[i], i, true);
}
@ -298,17 +304,17 @@ __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width
skip_cnt = len;
}
WideT const* in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
const idxT len_cast = (len - skip_cnt) / items_per_scalar;
const IdxT len_cast = (len - skip_cnt) / items_per_scalar;
const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width;
for (idxT i = tid; i < len_cast_for_sync; i += stride)
const IdxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width;
for (IdxT i = tid; i < len_cast_for_sync; i += stride)
{
bool valid = i < len_cast;
if (valid)
{
wide.scalar = in_cast[i];
}
const idxT real_i = skip_cnt + i * items_per_scalar;
const IdxT real_i = skip_cnt + i * items_per_scalar;
#pragma unroll
for (int j = 0; j < items_per_scalar; ++j)
{
@ -325,7 +331,7 @@ __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width
T value = valid ? in[tid] : T();
f(value, tid, valid);
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid;
const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid;
valid = remain_i < len;
value = valid ? in[remain_i] : T();
f(value, remain_i, valid);
@ -1166,6 +1172,77 @@ __global__ void radix_topk_one_block_kernel(T const* in, IdxT const* in_idx, con
} // namespace air_topk_stable
//}
namespace moe_topk
{
namespace cg = cooperative_groups;
static constexpr int kBLOCK_SIZE = 1024;
static constexpr int kWARP_SIZE = 32;
static constexpr int kWARPS_PER_BLOCK = kBLOCK_SIZE / kWARP_SIZE;
template <typename T>
__device__ T negativeInfinity()
{
return -INFINITY;
}
template <>
__device__ half negativeInfinity<half>()
{
return -CUDART_INF_FP16;
}
template <>
__device__ __nv_bfloat16 negativeInfinity<__nv_bfloat16>()
{
return -CUDART_INF_BF16;
}
/**************** TopK kernel for candidate number <= 128 and K <= 8 ****************/
template <typename InputT, typename OutputT, typename IdxT, int MaxLen, int MaxTopK>
__global__ void moe_topk_kernel(
InputT const* in, OutputT* out, IdxT* outIdx, int32_t const batchSize, int32_t const len, int32_t const topK)
{
uint32_t const blockRank = blockIdx.x;
uint32_t const tIdx = kBLOCK_SIZE * blockRank + threadIdx.x;
uint32_t const warpIdx = tIdx / kWARP_SIZE;
uint32_t const laneIdx = tIdx % kWARP_SIZE;
uint32_t const warpNum = gridDim.x * kWARPS_PER_BLOCK;
auto block = cg::this_thread_block();
auto warp = cg::tiled_partition<kWARP_SIZE>(block);
InputT minScore = negativeInfinity<InputT>();
for (uint32_t tokenId = warpIdx; tokenId < batchSize; tokenId += warpNum)
{
auto scoreOffset = tokenId * len;
auto outputOffset = tokenId * topK;
InputT inputScore[MaxLen / kWARP_SIZE];
IdxT inputIndex[MaxLen / kWARP_SIZE];
InputT warpTopKScore[MaxTopK];
IdxT warpTopKExpertIdx[MaxTopK];
// Load scores and indices for this warp
for (uint32_t i = 0; i < MaxLen / kWARP_SIZE; ++i)
{
auto expertIdx = i * kWARP_SIZE + laneIdx;
inputScore[i] = expertIdx < len ? static_cast<InputT>(in[scoreOffset + expertIdx]) : minScore;
inputIndex[i] = expertIdx;
}
// Reduce topK scores and indices for this warp
tensorrt_llm::kernels::reduce_topk::reduceTopK(
warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);
if (laneIdx < topK)
{
out[outputOffset + laneIdx] = static_cast<OutputT>(warpTopKScore[laneIdx]);
outIdx[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
}
} // end for tokenId
}
} // namespace moe_topk
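For orientation, a small hedged sketch of the register budget the kernel above relies on: each of the 32 lanes in a warp keeps MaxLen / kWARP_SIZE score/index pairs in registers, which is why the dispatcher below only instantiates MaxLen in {32, 64, 96, 128} and MaxTopK in {1, 2, 4, 8}. This is a standalone illustration, not part of this change:
#include <cstdio>
#include <initializer_list>
int main()
{
    constexpr int kWarpSize = 32;
    // Per-lane register cost grows linearly with the candidate count.
    for (int maxLen : {32, 64, 96, 128})
    {
        std::printf("MaxLen=%3d -> %d score/index pairs per lane\n", maxLen, maxLen / kWarpSize);
    }
    return 0;
}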
/***************Runtime API****************/
@ -1223,9 +1300,11 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx
IdxT* sort_in_idx = nullptr;
air_topk_stable::ComputeOffset<IdxT> computeoffset(k);
thrust::counting_iterator<IdxT> counting_iter(0);
thrust::transform_iterator<air_topk_stable::ComputeOffset<IdxT>, thrust::counting_iterator<IdxT>> transform_iter(
counting_iter, computeoffset);
cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size,
batch_size, transform_iter, transform_iter + 1, stream);
if (sorted)
@ -1277,8 +1356,8 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx
sort_in = static_cast<decltype(sort_in)>(aligned_pointers[9]);
sort_in_idx = static_cast<decltype(sort_in_idx)>(aligned_pointers[10]);
}
cudaMemsetAsync(
buf, 0, static_cast<char*>(aligned_pointers[2]) - static_cast<char*>(aligned_pointers[0]), stream);
cudaMemsetAsync(aligned_pointers[0], 0,
static_cast<char*>(aligned_pointers[2]) - static_cast<char*>(aligned_pointers[0]), stream);
}
T const* in_buf = nullptr;
@ -1423,36 +1502,120 @@ void standalone_stable_radix_topk_one_block_(void* buf, size_t& buf_size, T cons
}
}
template <typename T, typename idxT, bool sorted = false>
void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, idxT len, idxT k, T* out,
idxT* out_idx, bool greater, cudaStream_t stream = 0)
template <typename T, typename IdxT, bool sorted = false>
void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, IdxT len, IdxT k, T* out,
IdxT* out_idx, bool greater, cudaStream_t stream = 0)
{
constexpr int items_per_thread = 32;
constexpr int block_dim = 512;
constexpr bool fused_last_filter = false;
if (len <= block_dim * items_per_thread)
{
standalone_stable_radix_topk_one_block_<T, idxT, 11, block_dim>(
buf, buf_size, in, static_cast<idxT*>(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted);
standalone_stable_radix_topk_one_block_<T, IdxT, 11, block_dim>(
buf, buf_size, in, static_cast<IdxT*>(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted);
}
else
{
int sm_cnt = tensorrt_llm::common::getMultiProcessorCount();
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, idxT, 11, block_dim>(batch_size, len, sm_cnt);
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, IdxT, 11, block_dim>(batch_size, len, sm_cnt);
if (grid_dim == 1)
{
standalone_stable_radix_topk_one_block_<T, idxT, 11, block_dim>(buf, buf_size, in,
static_cast<idxT*>(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted);
standalone_stable_radix_topk_one_block_<T, IdxT, 11, block_dim>(buf, buf_size, in,
static_cast<IdxT*>(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted);
}
else
{
standalone_stable_radix_topk_<T, idxT, 11, block_dim>(buf, buf_size, in, static_cast<idxT*>(nullptr),
standalone_stable_radix_topk_<T, IdxT, 11, block_dim>(buf, buf_size, in, static_cast<IdxT*>(nullptr),
batch_size, len, k, out, out_idx, !greater, fused_last_filter, grid_dim, stream, sorted);
}
}
}
int nextPowerOfTwo(int num)
{
if (num <= 0)
{
return 1; // Handle invalid input
}
int power = 1;
while (power < num)
{
// Check for overflow before shifting
if (power > INT_MAX / 2)
{
return power;
}
power <<= 1;
}
return power;
}
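As a quick sanity check of the rounding behavior above, a minimal hedged sketch (standalone test harness, not part of this change; it assumes nextPowerOfTwo as defined above is visible):
#include <cassert>
int main()
{
    // Invalid input is clamped to 1; powers of two pass through unchanged;
    // anything else rounds up to the next power of two.
    assert(nextPowerOfTwo(0) == 1);
    assert(nextPowerOfTwo(1) == 1);
    assert(nextPowerOfTwo(5) == 8);
    assert(nextPowerOfTwo(32) == 32);
    assert(nextPowerOfTwo(33) == 64);
    return 0;
}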
template <typename T, typename IdxT>
void moe_reduce_topk(
T const* in, int batch_size, IdxT len, IdxT k, T* out, IdxT* out_idx, bool greater, cudaStream_t stream = 0)
{
using InputT = T;
using OutputT = T;
const uint32_t max_num_blocks = 1024;
const uint32_t num_blocks
= std::min(static_cast<uint32_t>((batch_size - 1) / moe_topk::kWARPS_PER_BLOCK + 1), max_num_blocks);
uint32_t max_len = nextPowerOfTwo(len) < 32 ? 32 : nextPowerOfTwo(len);
uint32_t moe_topk = nextPowerOfTwo(k);
auto* kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 8>;
switch (max_len)
{
case 32:
switch (moe_topk)
{
case 1: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 32, 1>; break;
case 2: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 32, 2>; break;
case 4: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 32, 4>; break;
case 8: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 32, 8>; break;
default: kernel_instance = nullptr; break;
}
break;
case 64:
switch (moe_topk)
{
case 1: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 64, 1>; break;
case 2: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 64, 2>; break;
case 4: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 64, 4>; break;
case 8: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 64, 8>; break;
default: kernel_instance = nullptr; break;
}
break;
case 96:
switch (moe_topk)
{
case 1: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 96, 1>; break;
case 2: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 96, 2>; break;
case 4: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 96, 4>; break;
case 8: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 96, 8>; break;
default: kernel_instance = nullptr; break;
}
break;
case 128:
switch (moe_topk)
{
case 1: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 1>; break;
case 2: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 2>; break;
case 4: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 4>; break;
case 8: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 8>; break;
default: kernel_instance = nullptr; break;
}
break;
default: kernel_instance = nullptr; break;
}
dim3 moe_topk_grid_dim(num_blocks);
dim3 moe_topk_block_dim(moe_topk::kBLOCK_SIZE);
kernel_instance<<<moe_topk_grid_dim, moe_topk_block_dim, 0, stream>>>(in, out, out_idx, batch_size, len, k);
}
#endif
///////////////
@ -1461,22 +1624,22 @@ template <typename T>
size_t invokeComputeTopkLastDimWorkspaceSize(
SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest)
{
using idxT = SizeType32;
using IdxT = SizeType32;
size_t buf_size = 0;
void* workspace = nullptr;
T const* in = nullptr;
T* out_val = nullptr;
idxT* out_idx = nullptr;
IdxT* out_idx = nullptr;
constexpr int block_dim = 512;
constexpr bool fused_last_filter = false;
constexpr bool sorted = true;
int sm_cnt = tensorrt_llm::common::getMultiProcessorCount();
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, idxT, 11, block_dim>(batchSize, inputLength, sm_cnt);
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, IdxT, 11, block_dim>(batchSize, inputLength, sm_cnt);
standalone_stable_radix_topk_<T, idxT, 11, block_dim>(workspace, buf_size, in, static_cast<idxT*>(nullptr),
standalone_stable_radix_topk_<T, IdxT, 11, block_dim>(workspace, buf_size, in, static_cast<IdxT*>(nullptr),
batchSize, inputLength, k, out_val, out_idx, !is_largest, fused_last_filter, grid_dim, 0, sorted);
return buf_size;
}
@ -1506,8 +1669,17 @@ void invokeTopkLastDim(SizeType32 batchSize, SizeType32 inputLength, SizeType32
T const* in = reinterpret_cast<T const*>(input);
T* out_val_ = reinterpret_cast<T*>(out_val);
SizeType32* out_idx_ = reinterpret_cast<SizeType32*>(out_idx);
standalone_stable_radix_11bits<T, SizeType32, true>(
workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream);
if (inputLength <= 128 && k <= 8 && is_largest == true)
{
// This path does not require a workspace buffer, but because the dispatch may fall back to
// AIR TopK for other shapes, the workspace is still allocated.
moe_reduce_topk(in, batchSize, inputLength, k, out_val_, out_idx_, !is_largest, stream);
}
else
{
standalone_stable_radix_11bits<T, SizeType32, true>(
workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream);
}
}
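The new dispatch above can be summarized by a small predicate; a hedged sketch (the helper name is hypothetical and not part of this change):
// The warp-based moe_reduce_topk path is only taken for small candidate sets
// with largest-value selection; all other shapes fall back to AIR TopK, which
// is why the workspace is still allocated unconditionally.
inline bool useWarpTopKPath(int inputLength, int k, bool isLargest)
{
    return inputLength <= 128 && k <= 8 && isLargest;
}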
#define INSTANTIATE_TOPK_LastDim_DATA_TYPE(T) \

View File

@ -378,7 +378,7 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx<BaseType>
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
// TODO: this is not sufficient to ensure visibility in the next kernel!
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
@ -757,15 +757,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
// Trigger secondary kernel.
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
// dependency sync.
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
// Trigger secondary kernel.
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
// dependency sync.
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
}

View File

@ -227,13 +227,11 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
// we can trigger the next kernel at this point
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
#endif
// at this point, all values for offsets are ready, except the final offsets

View File

@ -199,13 +199,11 @@ __global__ void __launch_bounds__(NumThreadsSingleBlock) routingIndicesBlockKern
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
// we can trigger the next kernel at this point
if constexpr (KernelParams::UsePdl)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
#endif
for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++)

View File

@ -43,7 +43,7 @@ target_link_libraries(
${Python3_LIBRARIES}
${TORCH_LIBRARIES}
torch_python
CUDA::cuda_driver
${CUDA_DRV_LIB}
${CUDA_NVML_LIB}
th_common)
target_compile_definitions(

View File

@ -285,5 +285,35 @@ struct type_caster<std::vector<std::reference_wrapper<T const>>>
return make_caster<std::vector<T>>::from_cpp(result, policy, cleanup);
}
};
template <>
struct type_caster<torch::ScalarType>
{
NB_TYPE_CASTER(torch::ScalarType, const_name("torch.dtype"));
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept
{
std::string dtype_name = nb::cast<std::string>(nb::str(src));
if (dtype_name.substr(0, 6) == "torch.")
{
dtype_name = dtype_name.substr(6);
}
auto const& dtype_map = c10::getStringToDtypeMap();
auto it = dtype_map.find(dtype_name);
if (it != dtype_map.end())
{
value = it->second;
return true;
}
return false;
}
static handle from_cpp(torch::ScalarType src, rv_policy policy, cleanup_list* cleanup)
{
throw std::runtime_error("from_cpp for torch::ScalarType is not implemented");
}
};
} // namespace detail
} // namespace NB_NAMESPACE
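As a rough illustration of the normalization this caster performs in from_python, a hedged standalone sketch (it assumes c10::getStringToDtypeMap() is keyed by the short names such as "float16", as the lookup above implies; the helper name is hypothetical):
#include <string>
// Mirrors the prefix handling above: "torch.float16" and "float16" both
// reduce to the key used for the dtype-map lookup.
std::string stripTorchPrefix(std::string name)
{
    constexpr char kPrefix[] = "torch.";
    if (name.rfind(kPrefix, 0) == 0)
    {
        name.erase(0, sizeof(kPrefix) - 1);
    }
    return name;
}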

View File

@ -240,7 +240,8 @@ void initBindings(nb::module_& m)
nb::class_<tle::KVCacheEvent>(executor_kv_cache, "KVCacheEvent")
.def_ro("event_id", &tle::KVCacheEvent::eventId)
.def_ro("data", &tle::KVCacheEvent::data)
.def_ro("window_size", &tle::KVCacheEvent::windowSize);
.def_ro("window_size", &tle::KVCacheEvent::windowSize)
.def_ro("attention_dp_rank", &tle::KVCacheEvent::attentionDpRank);
nb::class_<tle::KVCacheEventManager>(executor_kv_cache, "KVCacheEventManager")
.def(

View File

@ -27,6 +27,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/stl/chrono.h>
#include <nanobind/stl/function.h>
#include <nanobind/stl/list.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/shared_ptr.h>

View File

@ -279,7 +279,7 @@ void initBindings(nb::module_& m)
.def(nb::init<tr::GptDecoderBatched::CudaStreamPtr>(), nb::arg("stream"))
.def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_num_sequences"),
nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"))
.def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("output"), nb::arg("input"))
.def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("decoder_state"), nb::arg("input"))
.def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference)
.def("finalize", &tr::GptDecoderBatched::finalize, nb::arg("decoder_state"), nb::arg("batch_idx"),
nb::arg("sampling_config"), nb::arg("streaming"))

View File

@ -946,8 +946,8 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
std::optional<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> gemm2;
if (common::getEnvForceDeterministicMOE())
{
gemm1 = mMOERunner->getTactics()[0];
gemm2 = mMOERunner->getTactics()[0];
gemm1 = mMOERunner->getTactics(MoeGemmId::GEMM_1)[0];
gemm2 = mMOERunner->getTactics(MoeGemmId::GEMM_2)[0];
}
else
{
@ -1278,7 +1278,7 @@ void MixtureOfExpertsGemmProfiler::runTactic(int m, int n, int k, MixtureOfExper
auto MixtureOfExpertsGemmProfiler::getTactics(int m, int n, int k) const -> std::vector<Config>
{
assert(mRunner);
return mRunner->mMOERunner->getTactics();
return mRunner->mMOERunner->getTactics(backend.mGemmToProfile);
}
void MixtureOfExpertsGemmProfiler::initTmpData(

View File

@ -43,6 +43,7 @@ namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
using MoeMinLatencyParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MoeMinLatencyParams;
using MOEParallelismConfig = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MOEParallelismConfig;
using QuantParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::QuantParams;
using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId;
using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
using ActivationParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams;
using TmaWarpSpecializedGroupedGemmInput = CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput;

View File

@ -44,7 +44,7 @@ target_link_libraries(
${Python3_LIBRARIES}
${TORCH_LIBRARIES}
torch_python
CUDA::cuda_driver
${CUDA_DRV_LIB}
${CUDA_NVML_LIB}
th_common)
target_compile_definitions(

View File

@ -131,6 +131,7 @@ void DecoderState::setupSpeculativeDecodingBuffers(
mSpeculativeDecodingMode = speculativeDecodingMode;
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
auto& dInput = mJointDecodingInput;
@ -179,6 +180,7 @@ void DecoderState::setupSpeculativeDecodingBuffers(
DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs;
externalDraftTokensInputs.draftLogits = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
externalDraftTokensInputs.draftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, dtype);
externalDraftTokensInputs.draftProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
externalDraftTokensInputs.targetProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
externalDraftTokensInputs.numDraftTokens = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
@ -187,8 +189,8 @@ void DecoderState::setupSpeculativeDecodingBuffers(
= bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
externalDraftTokensInputs.useDraftLogitsHost
= bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<bool>::value);
externalDraftTokensInputs.draftTokenIds
= bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
externalDraftTokensInputs.draftTokenIds = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
externalDraftTokensInputs.draftTokenIdsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvTokenIdType);
dInput->externalDraftTokensInputs = externalDraftTokensInputs;
}
@ -366,10 +368,16 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con
{mMaxNumSequences, mMaxDecodingEngineTokens, mMaxBeamWidth, static_cast<SizeType32>(vocabSizePadded)});
dInput.externalDraftTokensInputs->draftProbs->reshape(probsShape);
dInput.externalDraftTokensInputs->targetProbs->reshape(probsShape);
dInput.externalDraftTokensInputs->draftLogits->reshape(
ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens, static_cast<SizeType32>(vocabSizePadded)}));
dInput.externalDraftTokensInputs->draftTokenIds->reshape(
ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens}));
auto const logitsShape = ITensor::makeShape(
{mMaxNumSequences, mMaxDecodingEngineTokens, static_cast<SizeType32>(vocabSizePadded)});
dInput.externalDraftTokensInputs->draftLogits->reshape(logitsShape);
dInput.externalDraftTokensInputs->draftLogitsHost->reshape(logitsShape);
auto const tokenIdsShape = ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens});
dInput.externalDraftTokensInputs->draftTokenIds->reshape(tokenIdsShape);
dInput.externalDraftTokensInputs->draftTokenIdsHost->reshape(tokenIdsShape);
dInput.externalDraftTokensInputs->numDraftTokens->reshape(maxNumSequencesShape);
dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(maxNumSequencesShape);
dInput.externalDraftTokensInputs->useDraftLogits->reshape(maxNumSequencesShape);

View File

@ -83,7 +83,7 @@ add_library(
reducescatterOp.cpp
relativeAttentionBiasOp.cpp
dsv3RouterGemmOp.cpp
renormMoeRoutingOp.cpp
customMoeRoutingOp.cpp
selectiveScanOp.cpp
userbuffersFinalizeOp.cpp
userbuffersTensor.cpp
@ -119,9 +119,9 @@ endif()
if(NOT WIN32)
set_target_properties(
th_common
PROPERTIES LINK_FLAGS
"-Wl,-rpath='$ORIGIN' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
th_common PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/../../nvidia/nccl/lib")
set_target_properties(
th_common PROPERTIES LINK_FLAGS "${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
else()
target_link_libraries(th_common PRIVATE context_attention_src)
endif()

View File

@ -15,7 +15,7 @@
*/
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h"
#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
#include "tensorrt_llm/runtime/torchUtils.h"
namespace th = torch;
@ -25,7 +25,8 @@ namespace tk = tensorrt_llm::kernels;
namespace torch_ext
{
std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& router_logits, int64_t topk)
template <bool DoSoftmaxBeforeTopK>
std::tuple<at::Tensor, at::Tensor> custom_moe_routing_op(th::Tensor const& router_logits, int64_t topk)
{
auto data_type = router_logits.scalar_type();
auto input_size = router_logits.sizes();
@ -44,20 +45,22 @@ std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& route
{
case torch::kFloat32:
// Handle Float32
tk::invokeRenormMoeRouting<float, float, int32_t>(reinterpret_cast<float*>(router_logits.mutable_data_ptr()),
tk::invokeRenormMoeRouting<float, float, int32_t, DoSoftmaxBeforeTopK>(
reinterpret_cast<float*>(router_logits.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream);
break;
case torch::kBFloat16:
// Handle BFloat16
tk::invokeRenormMoeRouting<__nv_bfloat16, float, int32_t>(
tk::invokeRenormMoeRouting<__nv_bfloat16, float, int32_t, DoSoftmaxBeforeTopK>(
reinterpret_cast<__nv_bfloat16*>(router_logits.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream);
break;
case torch::kHalf:
// Handle Half
tk::invokeRenormMoeRouting<half, float, int32_t>(reinterpret_cast<half*>(router_logits.mutable_data_ptr()),
tk::invokeRenormMoeRouting<half, float, int32_t, DoSoftmaxBeforeTopK>(
reinterpret_cast<half*>(router_logits.mutable_data_ptr()),
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream);
break;
@ -69,6 +72,15 @@ std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& route
return {topk_indices, topk_values};
}
std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& router_logits, int64_t topk)
{
return custom_moe_routing_op<false>(router_logits, topk);
}
std::tuple<at::Tensor, at::Tensor> default_moe_routing_op(th::Tensor const& router_logits, int64_t topk)
{
return custom_moe_routing_op<true>(router_logits, topk);
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
@ -82,3 +94,15 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("renorm_moe_routing_op", &torch_ext::renorm_moe_routing_op);
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"default_moe_routing_op(Tensor router_logits, SymInt topk"
") -> (Tensor, Tensor)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("default_moe_routing_op", &torch_ext::default_moe_routing_op);
}
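For reference, a hedged usage sketch of the two routing ops registered above (the wrapper function and call pattern are assumptions, not part of this change; both ops forward to custom_moe_routing_op and differ only in the DoSoftmaxBeforeTopK template flag):
#include <torch/torch.h>
#include <tuple>
// default_moe_routing_op == custom_moe_routing_op<true>  (softmax before top-k)
// renorm_moe_routing_op  == custom_moe_routing_op<false> (top-k on raw logits)
std::tuple<at::Tensor, at::Tensor> routeTokens(
    at::Tensor const& routerLogits, int64_t topk, bool softmaxBeforeTopK)
{
    return softmaxBeforeTopK ? torch_ext::default_moe_routing_op(routerLogits, topk)
                             : torch_ext::renorm_moe_routing_op(routerLogits, topk);
}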

View File

@ -48,6 +48,7 @@ namespace common = tensorrt_llm::common;
namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
using ActivationParams = CUTLASS_MOE_GEMM_NAMESPACE::ActivationParams;
using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId;
// Always use public header as it is just utility functions and types
using TmaWarpSpecializedGroupedGemmInput = tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend;
@ -215,7 +216,8 @@ public:
mKernelRunner->use_fused_finalize_ = mUseFusedFinalize;
mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
mAllProfiles = mKernelRunner->getTactics();
mGemm1Profiles = mKernelRunner->getTactics(MoeGemmId::GEMM_1);
mGemm2Profiles = mKernelRunner->getTactics(MoeGemmId::GEMM_2);
}
~FusedMoeRunner()
@ -585,10 +587,11 @@ public:
return std::make_tuple(output, num_active_experts_per_node, experts_to_token_score, active_expert_global_ids);
}
int64_t getTacticNum()
int64_t getTacticNum(int64_t const gemm_idx)
{
std::lock_guard<std::mutex> lock(mMutex);
return mAllProfiles.size();
TORCH_CHECK(gemm_idx == 1 || gemm_idx == 2, "gemm_idx must be 1 or 2");
return (gemm_idx == 1) ? mGemm1Profiles.size() : mGemm2Profiles.size();
}
// TODO Update this to be able to tell if we are profiling swiglu bias
@ -624,10 +627,14 @@ public:
: group_size_;
int const num_experts = static_cast<int>(fc2_expert_weights.sizes()[0] * ep_size);
auto const gemm_to_profile
= (gemm_idx == 1) ? profiler_backend::GemmToProfile::GEMM_1 : profiler_backend::GemmToProfile::GEMM_2;
auto const& profiles = (gemm_idx == 1) ? mGemm1Profiles : mGemm2Profiles;
// Get the specific profile config according to the profile_id.
// The fallback tactic is index 0.
// TODO: use the best tactic id found offline for a better default inference perf
auto const& profile = profile_id == -1 ? mAllProfiles.front() : mAllProfiles[profile_id];
auto const& profile = profile_id == -1 ? profiles.front() : profiles[profile_id];
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
@ -638,8 +645,7 @@ public:
if (do_preparation)
{
// Set profiled gemm idx
mProfiler->mGemmToProfile
= (gemm_idx == 1) ? profiler_backend::GemmToProfile::GEMM_1 : profiler_backend::GemmToProfile::GEMM_2;
mProfiler->mGemmToProfile = gemm_to_profile;
// mProfiler init
auto parallelism_config = kernels::MOEParallelismConfig(static_cast<int>(tp_size),
@ -704,7 +710,8 @@ private:
bool mUseFusedFinalize = true;
using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
std::vector<Profile> mAllProfiles;
std::vector<Profile> mGemm1Profiles;
std::vector<Profile> mGemm2Profiles;
void freeProfileWorkspace()
{
@ -730,15 +737,15 @@ private:
return;
}
auto best_gemm1_profile = mAllProfiles.front();
auto best_gemm2_profile = mAllProfiles.front();
auto best_gemm1_profile = mGemm1Profiles.front();
auto best_gemm2_profile = mGemm2Profiles.front();
if (profile_ids.has_value())
{
TORCH_CHECK(profile_ids.value().size() == 2, "Expecting 2 profile ids");
best_gemm1_profile
= profile_ids.value()[0] == -1 ? best_gemm1_profile : mAllProfiles.at(profile_ids.value()[0]);
= profile_ids.value()[0] == -1 ? best_gemm1_profile : mGemm1Profiles.at(profile_ids.value()[0]);
best_gemm2_profile
= profile_ids.value()[1] == -1 ? best_gemm2_profile : mAllProfiles.at(profile_ids.value()[1]);
= profile_ids.value()[1] == -1 ? best_gemm2_profile : mGemm2Profiles.at(profile_ids.value()[1]);
}
mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
}

View File

@ -99,7 +99,7 @@ TEST_F(RequestInfoTest, Basic)
}
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"});
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT});
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, nvinfer1::DataType::kFLOAT});
RequestInfo info{1, *state};
auto info2 = serializeDeserialize(info);
EXPECT_EQ(info, info2);
@ -133,7 +133,7 @@ TEST_F(CacheConfigTest, EqualTo)
constexpr SizeType32 tokensPerBlock{64};
constexpr SizeType32 tensorParallelism{8};
constexpr SizeType32 pipelineParallelism{2};
constexpr SizeType32 contextParallelism{1};
constexpr SizeType32 contextParallelism{2};
constexpr SizeType32 sizePerHead{hiddenSize / nbHeads};
constexpr CacheState::AttentionType attentionType{CacheState::AttentionType::kDEFAULT};
constexpr int kvFactor = 2;
@ -148,7 +148,7 @@ TEST_F(CacheConfigTest, EqualTo)
texec::kv_cache::CacheState state0{
cacheStateCfg, worldConfig, modelConfig.getKvDataType(), attentionType, kvFactor};
texec::kv_cache::CacheState state1{nbAttentionLayers, nbHeads, sizePerHead, tokensPerBlock, tensorParallelism,
pipelineParallelism, dtype, attentionType, kvFactor, false, 0, tensorParallelism};
pipelineParallelism, contextParallelism, dtype, attentionType, kvFactor, false, 0, tensorParallelism};
EXPECT_EQ(state0, state1);
}
@ -165,7 +165,7 @@ public:
ON_CALL(*this, recvRequestInfo)
.WillByDefault(Return(RequestInfo{0,
texec::DataTransceiverState{
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1));
}
@ -217,7 +217,8 @@ TEST_F(MockTransceiverTest, MpiResponderBasic)
auto sender = std::make_unique<MockDataSender>();
EXPECT_CALL(*sender, recvRequestInfo)
.WillOnce(Return(RequestInfo{0,
texec::DataTransceiverState{texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT},
texec::DataTransceiverState{
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, nvinfer1::DataType::kFLOAT},
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
EXPECT_CALL(*sender, sendSync).WillOnce(Return());
EXPECT_CALL(*sender, getCounterpartsCount).WillOnce(Return(1));
@ -318,7 +319,7 @@ protected:
dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, CacheType::kSELF,
std::nullopt, nullptr, true);
mCacheState = std::make_unique<texec::kv_cache::CacheState>(
numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, dataType);
numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, dataType);
if (tensorrt_llm::common::getEnvUseUCXKvCache())
{
@ -506,7 +507,7 @@ TEST_F(SymmetricalCacheTest, SimpleTest)
#if ENABLE_MULTI_DEVICE
using AsymmetricTestParam
= std::tuple<int, int, int, int, int, int, int, int, nvinfer1::DataType, int, bool, bool, bool, bool>;
= std::tuple<int, int, int, int, int, int, int, int, int, int, nvinfer1::DataType, int, bool, bool, bool, bool>;
class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestParam>
{
@ -516,8 +517,8 @@ protected:
void TearDown() override {}
void setUpCommunicator(int contextTp, int contextPp, int genTp, int genPp, bool isMLA = false,
bool contextDP = false, bool generationDP = false)
void setUpCommunicator(int contextTp, int contextPp, int contextCp, int genTp, int genPp, int genCp,
bool isMLA = false, bool contextDP = false, bool generationDP = false)
{
#if ENABLE_MULTI_DEVICE
tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
@ -572,11 +573,13 @@ protected:
{
mTpSize = contextTp;
mPpSize = contextPp;
mCpSize = contextCp;
}
if (mIsGeneration)
{
mTpSize = genTp;
mPpSize = genPp;
mCpSize = genCp;
}
mTpRank = mRankInInstance % mTpSize;
@ -585,6 +588,7 @@ protected:
mGenRankSize = genRanks;
mContextTpSize = contextTp;
mContextPpSize = contextPp;
mContextCpSize = contextCp;
EXPECT_EQ((sessionComm.getRank()), mRankInInstance);
EXPECT_EQ(sessionComm.getSize(), mSizeInInstance);
@ -696,11 +700,12 @@ protected:
texec::kv_cache::CacheState::AttentionType attentionType = isMLA
? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
mCacheState = std::make_unique<texec::kv_cache::CacheState>(numLayers, numHeadsPerRank, sizePerHead,
tokensPerBlock, mTpSize, mPpSize, dataType, attentionType, kvFactor, enableDPAttention, DPrank, DPsize);
mCacheState
= std::make_unique<texec::kv_cache::CacheState>(numLayers, numHeadsPerRank, sizePerHead, tokensPerBlock,
mTpSize, mPpSize, mCpSize, dataType, attentionType, kvFactor, enableDPAttention, DPrank, DPsize);
mContextCacheState = std::make_unique<texec::kv_cache::CacheState>(numLayers, numHeadsPerRankForContext,
sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, dataType, attentionType, kvFactor, mContextDP,
DPrank, mContextTpSize);
sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, mContextCpSize, dataType, attentionType,
kvFactor, mContextDP, DPrank, mContextTpSize);
// UVM seems to be incompatible with MPI; this is still under investigation.
bool constexpr useUvm = false;
@ -859,7 +864,8 @@ protected:
texec::kv_cache::CacheState cacheState{mContextCacheState->getModelConfig().mNbKvHeadsPerLayer,
mContextCacheState->getModelConfig().mSizePerHead, mContextCacheState->getModelConfig().mTokensPerBlock,
mContextCacheState->getParallelConfig().mTensorParallelism,
mContextCacheState->getParallelConfig().mPipelineParallelism, mContextCacheState->getDataType(),
mContextCacheState->getParallelConfig().mPipelineParallelism,
mContextCacheState->getParallelConfig().mContextParallelism, mContextCacheState->getDataType(),
mContextCacheState->getAttentionConfig().mAttentionType, mContextCacheState->getAttentionConfig().mKvFactor,
mContextCacheState->getParallelConfig().mEnableAttentionDP, contextDpRank,
mContextCacheState->getParallelConfig().mTensorParallelism};
@ -1094,8 +1100,8 @@ protected:
tensorrt_llm::mpi::MpiComm const* mComm;
tensorrt_llm::mpi::MpiComm mParticipatingComm{nullptr, false};
SizeType32 mWorldSize{0}, mRank{0}, mRankInInstance{0};
SizeType32 mSizeInInstance{0}, mTpRank{0}, mPpRank{0}, mTpSize{0}, mPpSize{0}, mContextRankSize{0}, mGenRankSize{0},
mContextTpSize{0}, mContextPpSize{0};
SizeType32 mSizeInInstance{0}, mTpRank{0}, mPpRank{0}, mTpSize{0}, mPpSize{0}, mCpSize{0}, mContextRankSize{0},
mGenRankSize{0}, mContextTpSize{0}, mContextPpSize{0}, mContextCpSize{0};
LlmRequest::RequestIdType mRequestId{0};
bool mContextDP{false};
bool mGenerationDP{false};
@ -1129,22 +1135,24 @@ TEST_P(AsymmetricalCacheTest, TestCase)
AsymmetricTestParam param = GetParam();
int contextTp = std::get<0>(param);
int contextPp = std::get<1>(param);
int genTp = std::get<2>(param);
int genPp = std::get<3>(param);
int numLayers = std::get<4>(param);
int numHeads = std::get<5>(param);
int sizePerHead = std::get<6>(param);
int tokensPerBlock = std::get<7>(param);
nvinfer1::DataType dataType = std::get<8>(param);
int contextCp = std::get<2>(param);
int genTp = std::get<3>(param);
int genPp = std::get<4>(param);
int genCp = std::get<5>(param);
int numLayers = std::get<6>(param);
int numHeads = std::get<7>(param);
int sizePerHead = std::get<8>(param);
int tokensPerBlock = std::get<9>(param);
nvinfer1::DataType dataType = std::get<10>(param);
int kvFactor = std::get<9>(param);
bool isMLA = std::get<10>(param);
bool contextDP = std::get<11>(param);
bool generationDP = std::get<12>(param);
int kvFactor = std::get<11>(param);
bool isMLA = std::get<12>(param);
bool contextDP = std::get<13>(param);
bool generationDP = std::get<14>(param);
bool isWindow = std::get<13>(param);
bool isWindow = std::get<15>(param);
setUpCommunicator(contextTp, contextPp, genTp, genPp, isMLA, contextDP, generationDP);
setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);
if (mIsContext || mIsGeneration)
{
@ -1221,21 +1229,23 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
AsymmetricTestParam param = GetParam();
int contextTp = std::get<0>(param);
int contextPp = std::get<1>(param);
int genTp = std::get<2>(param);
int genPp = std::get<3>(param);
int numLayers = std::get<4>(param);
int numHeads = std::get<5>(param);
int sizePerHead = std::get<6>(param);
int tokensPerBlock = std::get<7>(param);
nvinfer1::DataType dataType = std::get<8>(param);
int contextCp = std::get<2>(param);
int genTp = std::get<3>(param);
int genPp = std::get<4>(param);
int genCp = std::get<5>(param);
int numLayers = std::get<6>(param);
int numHeads = std::get<7>(param);
int sizePerHead = std::get<8>(param);
int tokensPerBlock = std::get<9>(param);
nvinfer1::DataType dataType = std::get<10>(param);
int kvFactor = std::get<9>(param);
bool isMLA = std::get<10>(param);
bool contextDP = std::get<11>(param);
bool generationDP = std::get<12>(param);
bool isWindow = std::get<13>(param);
int kvFactor = std::get<11>(param);
bool isMLA = std::get<12>(param);
bool contextDP = std::get<13>(param);
bool generationDP = std::get<14>(param);
bool isWindow = std::get<15>(param);
setUpCommunicator(contextTp, contextPp, genTp, genPp, isMLA, contextDP, generationDP);
setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);
if (mIsContext || mIsGeneration)
{
@ -1324,95 +1334,95 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
}
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0, AsymmetricalCacheTest,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true, false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithWindow, AsymmetricalCacheTest,
testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(5),
testing::Values(4), testing::Values(4), testing::Values(8),
testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1),
testing::Values(1), testing::Values(5), testing::Values(4), testing::Values(4), testing::Values(8),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1, AsymmetricalCacheTest,
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(8),
testing::Values(4), testing::Values(4), testing::Values(8),
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4),
testing::Values(1), testing::Values(8), testing::Values(4), testing::Values(4), testing::Values(8),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false, true)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest2, AsymmetricalCacheTest,
testing::Combine(testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(1, 4),
testing::Values(16), testing::Values(16), testing::Values(4), testing::Values(8),
testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false), testing::Values(false),
testing::Values(false), testing::Values(false)));
testing::Combine(testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(1),
testing::Values(1, 4), testing::Values(1), testing::Values(16), testing::Values(16), testing::Values(4),
testing::Values(8), testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false),
testing::Values(false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0ForMLA, AsymmetricalCacheTest,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1ForMLA, AsymmetricalCacheTest,
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(4),
testing::Values(1), testing::Values(4), testing::Values(8),
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4),
testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(8),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(true), testing::Values(true), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA2, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(true), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA3, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(true), testing::Values(true), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA1, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(true), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA2, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(4), testing::Values(1),
testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(4),
testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate1, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(2), testing::Values(2),
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(2),
testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(4, 2), testing::Values(1),
testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4, 2),
testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP,
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1, 2), testing::Values(2),
testing::Values(4), testing::Values(1, 2), testing::Values(4), testing::Values(16),
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1, 2),
testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1, 2), testing::Values(4),
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
#endif
@ -1430,8 +1440,10 @@ TEST(targetTest, CacheStateNODP)
int contextPP = 2;
int contextTP = 4;
int contextCP = 1;
int genPP = 2;
int genTP = 2;
int genCP = 1;
bool const contextEnableDP = false;
bool const genEnableDP = false;
@ -1441,10 +1453,10 @@ TEST(targetTest, CacheStateNODP)
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
auto const contextCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
tokensPerBlock, contextTP, contextPP, dataType, attentionType, kvFactor, contextEnableDP, 0, 0};
tokensPerBlock, contextTP, contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, 0, 0};
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, 0, 0};
tokensPerBlock, genTP, genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, 0, 0};
auto const contextTragetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);
@ -1504,8 +1516,10 @@ TEST(targetTest, CacheStateContextDP)
int contextPP = 1;
int contextTP = 4;
int contextCP = 1;
int genPP = 1;
int genTP = 2;
int genCP = 1;
bool contextEnableDP = true;
bool genEnableDP = true;
@ -1519,10 +1533,11 @@ TEST(targetTest, CacheStateContextDP)
auto const contextCache
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP,
contextPP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
auto const genCache
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, genTP,
genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
auto const contextTragetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);
@ -1625,10 +1640,11 @@ TEST(targetTest, CacheStateContextDP)
auto const contextCache
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP,
contextPP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
auto const genCache
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, genTP,
genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
auto const contextTragetInfo
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(contextCache, genCache, generationRank);

View File

@ -90,7 +90,7 @@ protected:
size_t maxNumTokens = 1024;
mTransBufferManager = std::make_unique<CacheTransBufferManager>(mCacheManager.get(), maxNumTokens);
mCacheState = std::make_unique<CacheState>(numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, dataType);
mCacheState = std::make_unique<CacheState>(numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, dataType);
}
void TearDown() override

View File

@ -726,7 +726,7 @@ TEST(SerializeUtilsTest, ContextPhaseParams)
{
auto state = std::make_unique<texec::DataTransceiverState>();
state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"});
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT});
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, nvinfer1::DataType::kFLOAT});
auto stats = texec::ContextPhaseParams({10, 20, 30, 40, 50, 60}, 0, state.release(), VecTokens{10, 20});
auto stats2 = serializeDeserialize(stats);
EXPECT_EQ(stats, stats2);

View File

@ -255,7 +255,8 @@ TEST_F(TransferAgentTest, SyncMessage)
checked = nixlAgent0->checkRemoteDescs(agent1, regMem3.getDescs());
} while (!checked);
auto syncMessage = std::string("agent_sync_message");
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem3.getDescs(), agent1, syncMessage};
nixlAgent0->notifySyncMessage(agent1, syncMessage);
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem3.getDescs(), agent1};
auto status = nixlAgent0->submitTransferRequests(writeReq);
auto notif = nixlAgent1->getNotifiedSyncMessages();
@ -302,7 +303,8 @@ TEST_F(TransferAgentTest, SyncMessage)
} while (!checked2);
std::string syncMessage4 = "four_agent_sync_message";
TransferRequest writeReq1{TransferOp::kWRITE, regMem2.getDescs(), regMem1.getDescs(), agent0, syncMessage4};
nixlAgent1->notifySyncMessage(agent0, syncMessage4);
TransferRequest writeReq1{TransferOp::kWRITE, regMem2.getDescs(), regMem1.getDescs(), agent0};
auto status1 = nixlAgent1->submitTransferRequests(writeReq1);
auto notif4 = nixlAgent0->getNotifiedSyncMessages();
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif4.size() == 0; i++)

View File

@ -370,8 +370,8 @@ protected:
float mSparseMixerEpsilon = 0.2f;
// Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths
bool mUseDeterministicHopperReduce = true;
// Default this to false. This only matters for K>2, and so by doing this we will test the fused and unfused paths
bool mUseFusedFinalize = false;
// Disable this for long running tests to speed up runtime
bool mIsLongTest = false;
@ -456,7 +456,7 @@ protected:
{
managed_buffers.clear();
mMoERunner.use_fused_finalize_ = k < 3 || !mUseDeterministicHopperReduce;
mMoERunner.use_fused_finalize_ = k < 3 || mUseFusedFinalize;
mHiddenSize = hidden_size;
mInterSize = hidden_size * mInterSizeFraction;
@ -1087,9 +1087,9 @@ protected:
return std::tuple{(void*) weight_1, (void*) weight_2, bias_1, bias2_ptr, scale_1, scale_2, scale_3};
}
auto getFilteredConfigs(int sm)
auto getFilteredConfigs(int sm, MoeGemmId gemm_id)
{
auto tactics = mMoERunner.getTactics();
auto tactics = mMoERunner.getTactics(gemm_id);
if (sm == 89 || sm >= 120)
{
// Filter some unsupported configs for L40S
@ -1120,17 +1120,27 @@ protected:
auto selectTacticsForArch(int sm)
{
bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT;
auto tactics = getFilteredConfigs(sm);
auto it = std::find_if(tactics.begin(), tactics.end(),
auto epilogue_fusion_type = (is_tma_warp_specialized && mUseFusedFinalize)
? tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE
: tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE;
auto tactics1 = getFilteredConfigs(sm, MoeGemmId::GEMM_1);
auto tactics2 = getFilteredConfigs(sm, MoeGemmId::GEMM_2);
auto it1 = std::find_if(tactics1.begin(), tactics1.end(),
[is_tma_warp_specialized](auto& c) { return c.is_tma_warp_specialized == is_tma_warp_specialized; });
if (it == tactics.end())
auto it2 = std::find_if(tactics2.begin(), tactics2.end(),
[is_tma_warp_specialized, epilogue_fusion_type](auto& c) {
return c.is_tma_warp_specialized == is_tma_warp_specialized
&& c.epilogue_fusion_type == epilogue_fusion_type;
});
if (it1 == tactics1.end() || it2 == tactics2.end())
{
// Fall back to any tactic
std::cout << "WARNING: Could not find config for sm version " << sm << std::endl;
return std::pair{tactics[0], tactics[0]};
it1 = (it1 == tactics1.end()) ? tactics1.begin() : it1;
it2 = (it2 == tactics2.end()) ? tactics2.begin() : it2;
}
return std::pair(*it, *it);
return std::pair(*it1, *it2);
}
using ConfigsToTestVec = std::vector<std::pair<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
@ -1164,7 +1174,7 @@ protected:
auto stream = mStream->get();
auto tactic1 = mInternalSelectedConfig1;
auto tactic2 = mInternalSelectedConfig2;
if (!tactic1)
if (!tactic1 || !tactic2)
{
int sm = getSMVersion();
std::tie(tactic1, tactic2) = selectTacticsForArch(sm);
@ -1629,8 +1639,9 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
auto [expected_experts, token_final_scales] = populateRouting(num_experts, num_tokens, k);
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k);
bool should_be_deterministic
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
bool is_finalize_fusion = gemm2.epilogue_fusion_type
== tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
bool should_be_deterministic = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@ -1749,7 +1760,7 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteSwigluBias)
TYPED_TEST(MixtureOfExpertsTest, PermuteNonDeterministic)
{
this->mUseDeterministicHopperReduce = false;
this->mUseFusedFinalize = true;
// Just test case 3, cases 1&2 always use the fused paths
this->BasicPermuteTest(3);
}
@ -1896,8 +1907,10 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
// Only need to init the inputs on the first iteration
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k,
MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
bool is_finalize_fusion = gemm2.epilogue_fusion_type
== tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
bool should_be_deterministic
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
= !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@ -1912,8 +1925,10 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
else
{
runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
bool is_finalize_fusion = gemm2.epilogue_fusion_type
== tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
bool should_be_deterministic
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
= !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
if (should_be_deterministic && !mIsLongTest)
{
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
@ -2077,6 +2092,7 @@ PARALLEL_TEST_SUITE(MixedParallel)
TYPED_TEST(MixtureOfExpertsTest, ConfigSweep)
{
this->mIsLongTest = true;
this->mUseFusedFinalize = true; // True for all cases because we sweep both
auto genConfigName = [](auto conf) -> std::string
{
using namespace tensorrt_llm::cutlass_extensions;
@ -2103,12 +2119,13 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep)
auto activation_pool = std::vector{ActivationType::Relu, ActivationType::Swiglu, ActivationType::SwigluBias};
if (this->NVFP4)
activation_pool = {ActivationType::Relu};
auto configs = this->getFilteredConfigs(getSMVersion());
auto configs1 = this->getFilteredConfigs(getSMVersion(), MoeGemmId::GEMM_1);
auto configs2 = this->getFilteredConfigs(getSMVersion(), MoeGemmId::GEMM_2);
for (auto const activation_type : activation_pool)
{
for (auto conf1 : configs)
for (auto conf1 : configs1)
{
for (auto conf2 : configs)
for (auto conf2 : configs2)
{
auto name1 = genConfigName(conf1);
auto name2 = genConfigName(conf2);
@ -2120,7 +2137,6 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep)
this->mActType = activation_type;
for (auto k : {2, 3})
{
this->mOverrideSelectedConfig1 = conf1;
this->mOverrideSelectedConfig2 = conf2;
this->BasicPermuteTest(k, this->MINIMUM_ALIGNMENT);

View File

@ -4,8 +4,9 @@ set -ex
GITHUB_URL="https://github.com"
UCX_INSTALL_PATH="/usr/local/ucx/"
CUDA_PATH="/usr/local/cuda"
NIXL_VERSION="0.3.1"
NIXL_VERSION="0.5.0"
NIXL_REPO="https://github.com/ai-dynamo/nixl.git"
OLD_LD_LIBRARY_PATH=$LD_LIBRARY_PATH
ARCH_NAME="x86_64-linux-gnu"
GDS_PATH="$CUDA_PATH/targets/x86_64-linux"
@ -18,25 +19,26 @@ pip3 install --no-cache-dir meson ninja pybind11
git clone --depth 1 -b ${NIXL_VERSION} ${NIXL_REPO}
cd nixl
cuda_path=$(find / -name "libcuda.so.1" 2>/dev/null | head -n1)
if [[ -z "$cuda_path" ]]; then
echo "libcuda.so.1 not found "
CUDA_SO_PATH=$(find "/usr/local" -name "libcuda.so.1" 2>/dev/null | head -n1)
if [[ -z "$CUDA_SO_PATH" ]]; then
echo "libcuda.so.1 not found"
exit 1
fi
ln -sf $cuda_path $CUDA_PATH/lib64/libcuda.so.1
CUDA_SO_PATH=$(dirname $CUDA_SO_PATH)
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CUDA_SO_PATH
meson setup builddir \
-Ducx_path=$UCX_INSTALL_PATH \
-Dcudapath_lib="$CUDA_PATH/lib64" \
-Dcudapath_inc="$CUDA_PATH/include" \
-Dgds_path="$GDS_PATH" \
-Dinstall_headers=true \
-Dstatic_plugins=UCX
-Dinstall_headers=true
cd builddir && ninja install
cd ../..
rm -rf nixl* # Remove NIXL source tree to save space
rm $CUDA_PATH/lib64/libcuda.so.1
export LD_LIBRARY_PATH=$OLD_LD_LIBRARY_PATH
echo "export LD_LIBRARY_PATH=/opt/nvidia/nvda_nixl/lib/${ARCH_NAME}:/opt/nvidia/nvda_nixl/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"

View File

@ -2,29 +2,28 @@
set -ex
GITHUB_URL="https://github.com"
UCX_VERSION="v1.19.0"
UCX_VERSION="v1.19.x"
UCX_INSTALL_PATH="/usr/local/ucx/"
CUDA_PATH="/usr/local/cuda"
UCX_REPO="https://github.com/openucx/ucx.git"
if [ ! -d ${UCX_INSTALL_PATH} ]; then
git clone --depth 1 -b ${UCX_VERSION} ${UCX_REPO}
cd ucx
./autogen.sh
./contrib/configure-release \
--prefix=${UCX_INSTALL_PATH} \
--enable-shared \
--disable-static \
--disable-doxygen-doc \
--enable-optimizations \
--enable-cma \
--enable-devel-headers \
--with-cuda=${CUDA_PATH} \
--with-verbs \
--with-dm \
--enable-mt
make install -j$(nproc)
cd ..
rm -rf ucx # Remove UCX source to save space
echo "export LD_LIBRARY_PATH=${UCX_INSTALL_PATH}/lib:\$LD_LIBRARY_PATH" >> "${ENV}"
fi
rm -rf ${UCX_INSTALL_PATH}
git clone --depth 1 -b ${UCX_VERSION} ${UCX_REPO}
cd ucx
./autogen.sh
./contrib/configure-release \
--prefix=${UCX_INSTALL_PATH} \
--enable-shared \
--disable-static \
--disable-doxygen-doc \
--enable-optimizations \
--enable-cma \
--enable-devel-headers \
--with-cuda=${CUDA_PATH} \
--with-verbs \
--with-dm \
--enable-mt
make install -j$(nproc)
cd ..
rm -rf ucx # Remove UCX source to save space
echo "export LD_LIBRARY_PATH=${UCX_INSTALL_PATH}/lib:\$LD_LIBRARY_PATH" >> "${ENV}"

View File

@ -19,11 +19,11 @@ We have a forthcoming guide for achieving great performance on H100; however, th
In this section, we introduce several ways to install TensorRT-LLM.
### NGC Docker Image of dev branch
### NGC Docker Image
Day-0 support for gpt-oss is provided via the NGC container image `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev`. This image was built on top of the pre-day-0 **dev branch**. This container is multi-platform and will run on both x64 and arm64 architectures.
Visit the [NGC TensorRT-LLM Release page](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release) to find the most up-to-date NGC container image to use. You can also check the latest [release notes](https://github.com/NVIDIA/TensorRT-LLM/releases) to keep track of the support status of the latest releases.
Run the following docker command to start the TensorRT-LLM container in interactive mode:
Run the following Docker command to start the TensorRT-LLM container in interactive mode (change the image tag to match the latest release):
```bash
docker run --rm --ipc=host -it \
@ -33,7 +33,7 @@ docker run --rm --ipc=host -it \
-p 8000:8000 \
-e TRTLLM_ENABLE_PDL=1 \
-v ~/.cache:/root/.cache:rw \
nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev \
nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc0 \
/bin/bash
```
@ -53,9 +53,9 @@ Additionally, the container mounts your user `.cache` directory to save the down
Support for gpt-oss has been [merged](https://github.com/NVIDIA/TensorRT-LLM/pull/6645) into the **main branch** of TensorRT-LLM. As we continue to optimize gpt-oss performance, you can build TensorRT-LLM from source to get the latest features and support. Please refer to the [doc](https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html) if you want to build from source yourself.
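For orientation, the sketch below outlines a typical source-build flow. It assumes git-lfs and the required CUDA toolchain are available (for example, inside the TensorRT-LLM dev container); the exact build flags and wheel output path may differ from your setup, so treat the linked guide as authoritative.
```bash
# Minimal build-from-source sketch; flags and paths are illustrative, see the linked guide for details.
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
git submodule update --init --recursive   # fetch third-party dependencies
git lfs pull                              # fetch LFS-tracked sources
python3 ./scripts/build_wheel.py          # build the TensorRT-LLM wheel
pip3 install ./build/tensorrt_llm-*.whl   # install the freshly built wheel
```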
### Regular Release of TensorRT-LLM
### TensorRT-LLM Python Wheel Install
Since gpt-oss has been supported on the main branch, you can get TensorRT-LLM out of the box through its regular release in the future. Please check the latest [release notes](https://github.com/NVIDIA/TensorRT-LLM/releases) to keep track of the support status. The release is provided as [NGC Container Image](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags) or [pip Python wheel](https://pypi.org/project/tensorrt-llm/#history). You can find instructions on pip install [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html).
Regular releases of TensorRT-LLM are also provided as [Python wheels](https://pypi.org/project/tensorrt-llm/#history). You can find pip install instructions [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html).
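As a rough sketch (the linked installation guide is authoritative for prerequisites such as the Python version and any extra index URLs), installing a released wheel looks like this:
```bash
# Install a released TensorRT-LLM wheel into a CUDA-enabled Python environment.
pip3 install --upgrade pip
pip3 install tensorrt_llm
python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"   # sanity check
```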
## Performance Benchmarking and Model Serving
@ -210,7 +210,10 @@ We can use `trtllm-serve` to serve the model by translating the benchmark comman
```bash
trtllm-serve \
gpt-oss-120b \ # Or ${local_model_path}
Note: You can also point to a local path containing the model weights instead of the HF repo (e.g., `${local_model_path}`).
trtllm-serve \
openai/gpt-oss-120b \
--host 0.0.0.0 \
--port 8000 \
--backend pytorch \
@ -228,7 +231,8 @@ For max-throughput configuration, run:
```bash
trtllm-serve \
gpt-oss-120b \ # Or ${local_model_path}
trtllm-serve \
openai/gpt-oss-120b \
--host 0.0.0.0 \
--port 8000 \
--backend pytorch \
@ -262,7 +266,7 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '
"messages": [
{
"role": "user",
"content": "What is NVIDIA's advantage for inference?"
"content": "What is NVIDIAs advantage for inference?"
}
],
"max_tokens": 1024,
@ -348,12 +352,7 @@ others according to your needs.
## (H200/H100 Only) Using OpenAI Triton Kernels for MoE
OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels for Hopper-based GPUs like NVIDIA's H200 for optimal performance. `TRTLLM` MoE backend is not supported on Hopper, and `CUTLASS` backend support is still ongoing. Please enable `TRITON` backend with the steps below if you are running on Hopper GPUs.
### Installing OpenAI Triton
The `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev` has prepared Triton already (`echo $TRITON_ROOT` could reveal the path). In other situations, you will need to build and install a specific version of Triton. Please follow the instructions in this [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/gpt_oss#using-openai-triton-kernels-for-moe).
OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels on Hopper-based GPUs such as NVIDIA's H200 for optimal performance. The `TRTLLM` MoE backend is not supported on Hopper, and `CUTLASS` backend support is still in progress. Please follow the instructions in this [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/gpt_oss#using-openai-triton-kernels-for-moe) to install and enable the `TRITON` MoE kernels on Hopper GPUs.
### Selecting Triton as the MoE backend

View File

@ -0,0 +1,328 @@
# Quick Start Recipe for GPT-OSS on TensorRT-LLM - Blackwell Hardware
## Introduction
This deployment guide provides step-by-step instructions for running the GPT-OSS model using TensorRT-LLM, optimized for NVIDIA GPUs. It covers the complete setup required, from accessing model weights and preparing the software environment to configuring TensorRT-LLM parameters, launching the server, and validating inference output.
The guide is intended for developers and practitioners seeking high-throughput or low-latency inference using NVIDIA's accelerated stack, starting from the TensorRT-LLM NGC container image and moving through server configuration and model serving.
## Prerequisites
* GPU: NVIDIA Blackwell Architecture
* OS: Linux
* Drivers: CUDA Driver 575 or Later
* Docker with NVIDIA Container Toolkit installed
* Python3 and python3-pip (Optional, for accuracy evaluation only)
## Models
* MXFP4 model: [GPT-OSS-120B](https://huggingface.co/openai/gpt-oss-120b)
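If you prefer to pre-download the weights rather than letting `trtllm-serve` fetch them on first launch, a minimal sketch using the Hugging Face CLI (assuming the CLI is installed and you are logged in if the model requires it) is:
```shell
pip3 install -U "huggingface_hub[cli]"
# Downloads to ~/.cache/huggingface/hub/ by default, which is the directory mounted into the container below.
huggingface-cli download openai/gpt-oss-120b
```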
## MoE Backend Support Matrix
There are multiple MoE backends inside TRT-LLM. The support matrix of the MoE backends is shown below.
| Device | Activation Type | MoE Weights Type | MoE Backend | Use Case |
|------------|------------------|------------------|-------------|----------------|
| B200/GB200 | MXFP8 | MXFP4 | TRTLLM | Low Latency |
| B200/GB200 | MXFP8 | MXFP4 | CUTLASS | Max Throughput |
The default MoE backend is `CUTLASS`, so for combinations that `CUTLASS` does not support, you must set `moe_config.backend` explicitly to run the model.
## Deployment Steps
### Run Docker Container
Run the docker container using the TensorRT-LLM NVIDIA NGC image.
```shell
docker run --rm -it \
--ipc=host \
--gpus all \
-p 8000:8000 \
-v ~/.cache:/root/.cache:rw \
--name tensorrt_llm \
nvcr.io/nvidia/tensorrt-llm/release:1.0.0rc6 \
/bin/bash
```
Note:
* The command mounts your user `.cache` directory to save the downloaded model checkpoints, which are saved to `~/.cache/huggingface/hub/` by default. This prevents having to redownload the weights each time you rerun the container. If the `~/.cache` directory doesn't exist, please create it using `mkdir ~/.cache`.
* You can mount additional directories and paths using the `-v <host_path>:<container_path>` flag if needed, such as mounting the downloaded weight paths.
* The command also maps port `8000` from the container to your host so that you can access the LLM API endpoint from the host.
* See <https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags> for all available containers. Containers published weekly from the main branch carry an `rcN` suffix, while the monthly releases that go through QA testing have no `rcN` suffix. Use an `rc` release to get the latest model and feature support.
If you want to use the latest main branch, you can build TensorRT-LLM from source instead; see <https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html> for the steps.
### Creating the TRT-LLM Server config
We create a YAML configuration file `/tmp/config.yml` for the TensorRT-LLM Server and populate it with the following recommended performance settings.
For low latency with the `TRTLLM` MoE backend:
```shell
EXTRA_LLM_API_FILE=/tmp/config.yml
cat << EOF > ${EXTRA_LLM_API_FILE}
enable_attention_dp: false
cuda_graph_config:
enable_padding: true
max_batch_size: 128
moe_config:
backend: TRTLLM
EOF
```
For max throughput with the `CUTLASS` MoE backend:
```shell
EXTRA_LLM_API_FILE=/tmp/config.yml
cat << EOF > ${EXTRA_LLM_API_FILE}
enable_attention_dp: true
cuda_graph_config:
enable_padding: true
max_batch_size: 128
moe_config:
backend: CUTLASS
EOF
```
### Launch the TRT-LLM Server
Below is an example command to launch the TRT-LLM server with the GPT-OSS model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section.
```shell
trtllm-serve openai/gpt-oss-120b \
--host 0.0.0.0 \
--port 8000 \
--backend pytorch \
--max_batch_size 128 \
--max_num_tokens 16384 \
--max_seq_len 2048 \
--kv_cache_free_gpu_memory_fraction 0.9 \
--tp_size 8 \
--ep_size 8 \
--trust_remote_code \
--extra_llm_api_options ${EXTRA_LLM_API_FILE}
```
After the server is set up, the client can now send prompt requests to the server and receive results.
### Configs and Parameters
These options are used directly on the command line when you start the `trtllm-serve` process.
#### `--tp_size`
* **Description:** Sets the **tensor-parallel size**. This should typically match the number of GPUs you intend to use for a single model instance.
#### `--ep_size`
* **Description:** Sets the **expert-parallel size** for Mixture-of-Experts (MoE) models. Like `tp_size`, this should generally match the number of GPUs you're using. This setting has no effect on non-MoE models.
#### `--kv_cache_free_gpu_memory_fraction`
* **Description:** A value between `0.0` and `1.0` that specifies the fraction of free GPU memory to reserve for the KV cache after the model is loaded. Since memory usage can fluctuate, this buffer helps prevent out-of-memory (OOM) errors.
* **Recommendation:** If you experience OOM errors, try reducing this value to `0.7` or lower.
#### `--backend pytorch`
* **Description:** Tells TensorRT-LLM to use the **pytorch** backend.
#### `--max_batch_size`
* **Description:** The maximum number of user requests that can be grouped into a single batch for processing.
#### `--max_num_tokens`
* **Description:** The maximum total number of tokens (across all requests) allowed inside a single scheduled batch.
#### `--max_seq_len`
* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens.
#### `--trust_remote_code`
* **Description:** Allows TensorRT-LLM to download models and tokenizers from Hugging Face. This flag is passed directly to the Hugging Face API.
#### Extra LLM API Options (YAML Configuration)
These options provide finer control over performance and are set within a YAML file passed to the `trtllm-serve` command via the `--extra_llm_api_options` argument.
#### `cuda_graph_config`
* **Description**: A section for configuring CUDA graphs to optimize performance.
* **Options**:
* `enable_padding`: If `"true"`, input batches are padded to the nearest `cuda_graph_batch_size`. This can significantly improve performance.
**Default**: `false`
* `max_batch_size`: Sets the maximum batch size for which a CUDA graph will be created.
**Default**: `0`
**Recommendation**: Set this to the same value as the `--max_batch_size` command-line option.
#### `moe_config`
* **Description**: Configuration for Mixture-of-Experts (MoE) models.
* **Options**:
* `backend`: The backend to use for MoE operations.
**Default**: `CUTLASS`
See the [`TorchLlmArgs` class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options which can be used in the `extra_llm_api_options`.
## Testing API Endpoint
### Basic Test
Start a new terminal on the host to test the TensorRT-LLM server you just launched.
You can query the health/readiness of the server using:
```shell
curl -s -o /dev/null -w "Status: %{http_code}\n" "http://localhost:8000/health"
```
When the `Status: 200` code is returned, the server is ready for queries. Note that the very first query may take longer due to initialization and compilation.
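If you are scripting against the server, a simple readiness loop such as the sketch below can be used to wait until the endpoint reports healthy before sending traffic:
```shell
# Poll the health endpoint until it returns HTTP 200 (assumes curl is available on the host).
until [ "$(curl -s -o /dev/null -w "%{http_code}" "http://localhost:8000/health")" = "200" ]; do
  echo "Waiting for the server to become ready..."
  sleep 5
done
echo "Server is ready."
```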
After the TRT-LLM server is set up and shows `Application startup complete`, you can send requests to the server.
```shell
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "openai/gpt-oss-120b",
"messages": [
{
"role": "user",
"content": "Where is New York?"
}
],
"max_tokens": 1024,
"top_p": 1.0
}' -w "\n"
```
Here is an example response, showing that the TRT-LLM server reasons about and answers the question.
TODO: Use Chat Completions API / Responses API as the example after the PR is merged.
```json
{"id":"chatcmpl-c5bf51b5cab94e10ba5da5266d12ee59","object":"chat.completion","created":1755815898,"model":"openai/gpt-oss-120b","choices":[{"index":0,"message":{"role":"assistant","content":"analysisThe user asks: \"Where is New York?\" Likely they want location info. Provide answer: New York State in northeastern US, New York City on the east coast, coordinates, etc. Provide context.assistantfinal**NewYork** can refer to two related places in the United States:\n\n| What it is | Where it is | Approx. coordinates | How to picture it |\n|------------|------------|--------------------|-------------------|\n| **NewYork State** | The northeastern corner of the United States, bordered by **Vermont, Massachusetts, Connecticut, NewJersey, Pennsylvania, and the Canadian provinces of Ontario and Quebec**. | 42.7°N,75.5°W (roughly the states geographic centre) | A roughly rectangular state that stretches from the Atlantic Ocean in the southeast to the Adirondack Mountains and the Great Lakes region in the north. |\n| **NewYork City (NYC)** | The largest city in the state, located on the **southern tip of the state** where the **Hudson River meets the Atlantic Ocean**. It occupies five boroughs: Manhattan, Brooklyn, Queens, The Bronx, and Staten Island. | 40.7128°N,74.0060°W | A dense, worldfamous metropolis that sits on a series of islands (Manhattan, StatenIsland, parts of the Bronx) and the mainland (Brooklyn and Queens). |\n\n### Quick geographic context\n- **On a map of the United States:** NewYork State is in the **Northeast** region, just east of the Great Lakes and north of Pennsylvania. \n- **From Washington, D.C.:** Travel roughly **225mi (360km) northeast**. \n- **From Boston, MA:** Travel about **215mi (350km) southwest**. \n- **From Toronto, Canada:** Travel about **500mi (800km) southeast**.\n\n### Travel tips\n- **By air:** Major airports include **JohnF.Kennedy International (JFK)**, **LaGuardia (LGA)**, and **Newark Liberty International (EWR)** (the latter is actually in NewJersey but serves the NYC metro area). \n- **By train:** Amtraks **Northeast Corridor** runs from **Boston → NewYork City → Washington, D.C.** \n- **By car:** Interstates **I87** (northsouth) and **I90** (eastwest) are the primary highways crossing the state.\n\n### Fun fact\n- The name “**NewYork**” was given by the English in 1664, honoring the Duke of York (later King JamesII). The citys original Dutch name was **“NewAmsterdam.”**\n\nIf you need more specific directions (e.g., how to get to a particular neighborhood, landmark, or the state capital **Albany**), just let me know!","reasoning_content":null,"tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null,"mm_embedding_handle":null,"disaggregated_params":null,"avg_decoded_tokens_per_iter":1.0}],"usage":{"prompt_tokens":72,"total_tokens":705,"completion_tokens":633},"prompt_token_ids":null}
```
### Troubleshooting Tips
* If you encounter CUDA out-of-memory errors, try reducing `max_batch_size` or `max_seq_len`.
* Ensure your model checkpoints are compatible with the expected format.
* For performance issues, check GPU utilization with `nvidia-smi` while the server is running.
* If the container fails to start, verify that the NVIDIA Container Toolkit is properly installed.
* For connection issues, make sure the server port (`8000` in this guide) is not being used by another application.
### Running Evaluations to Verify Accuracy (Optional)
We use OpenAI's official evaluation tool to test the model's accuracy. For more information, see [gpt-oss-eval](https://github.com/openai/gpt-oss/tree/main/gpt_oss/evals).
TODO(@Binghan Chen): Add instructions for running gpt-oss-eval.
## Benchmarking Performance
To benchmark the performance of your TensorRT-LLM server, you can leverage the built-in `benchmark_serving.py` script. To do this, first create a wrapper `bench.sh` script.
```shell
cat <<'EOF' > bench.sh
#!/usr/bin/env bash
set -euo pipefail
concurrency_list="32 64 128 256 512 1024 2048 4096"
multi_round=5
isl=1024
osl=1024
result_dir=/tmp/gpt_oss_output
for concurrency in ${concurrency_list}; do
num_prompts=$((concurrency * multi_round))
python -m tensorrt_llm.serve.scripts.benchmark_serving \
--model openai/gpt-oss-120b \
--backend openai \
--dataset-name "random" \
--random-input-len ${isl} \
--random-output-len ${osl} \
--random-prefix-len 0 \
--random-ids \
--num-prompts ${num_prompts} \
--max-concurrency ${concurrency} \
--ignore-eos \
--tokenize-on-client \
--percentile-metrics "ttft,tpot,itl,e2el"
done
EOF
chmod +x bench.sh
```
If you want to save the results to a file, add the following options to the `benchmark_serving` command in `bench.sh`.
```shell
--save-result \
--result-dir "${result_dir}" \
--result-filename "concurrency_${concurrency}.json"
```
For more benchmarking options, see <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serve/scripts/benchmark_serving.py>.
Run `bench.sh` to begin a serving benchmark. This will take a long time if you run all the concurrencies mentioned in the above `bench.sh` script.
```shell
./bench.sh
```
Sample TensorRT-LLM serving benchmark output. Your results may vary due to ongoing software optimizations.
```
============ Serving Benchmark Result ============
Successful requests: 16
Benchmark duration (s): 17.66
Total input tokens: 16384
Total generated tokens: 16384
Request throughput (req/s): [result]
Output token throughput (tok/s): [result]
Total Token throughput (tok/s): [result]
User throughput (tok/s): [result]
---------------Time to First Token----------------
Mean TTFT (ms): [result]
Median TTFT (ms): [result]
P99 TTFT (ms): [result]
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): [result]
Median TPOT (ms): [result]
P99 TPOT (ms): [result]
---------------Inter-token Latency----------------
Mean ITL (ms): [result]
Median ITL (ms): [result]
P99 ITL (ms): [result]
----------------End-to-end Latency----------------
Mean E2EL (ms): [result]
Median E2EL (ms): [result]
P99 E2EL (ms): [result]
==================================================
```
### Key Metrics
* Median Time to First Token (TTFT)
* The typical time elapsed from when a request is sent until the first output token is generated.
* Median Time Per Output Token (TPOT)
* The typical time required to generate each token *after* the first one.
* Median Inter-Token Latency (ITL)
* The typical time delay between the completion of one token and the completion of the next.
* Median End-to-End Latency (E2EL)
* The typical total time from when a request is submitted until the final token of the response is received.
* Total Token Throughput
* The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens.
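As a rough consistency check on these numbers, note that the latency metrics above are related: for a single request, TPOT is approximately (E2EL - TTFT) / (OSL - 1). For example, with an assumed TTFT of 120 ms, an assumed E2EL of 5,200 ms, and the 1,024-token output length used in `bench.sh`, TPOT would be roughly (5200 - 120) / 1023 ≈ 5 ms per token. These numbers are illustrative only, not measured results.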

View File

@ -38,6 +38,7 @@ Welcome to TensorRT-LLM's Documentation!
deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md
deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md
deployment-guide/quick-start-recipe-for-llama3.3-70b-on-trtllm.md
deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md
.. toctree::

View File

@ -0,0 +1,77 @@
# Serving with trtllm-serve
AutoDeploy integrates with the OpenAI-compatible `trtllm-serve` CLI so you can expose AutoDeploy-optimized models over HTTP without writing server code. This page shows how to launch the server with the AutoDeploy backend, configure it via YAML, and validate with a simple request.
## Quick start
Launch `trtllm-serve` with the AutoDeploy backend by setting `--backend _autodeploy`:
```bash
trtllm-serve \
meta-llama/Llama-3.1-8B-Instruct \
--backend _autodeploy
```
- `model`: HF name or local path
- `--backend _autodeploy`: uses AutoDeploy runtime
Once the server is ready, test with an OpenAI-compatible request:
```bash
curl -s http://localhost:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d '{
"model": "meta-llama/Llama-3.1-8B-Instruct",
"messages":[{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Where is New York? Tell me in a single sentence."}],
"max_tokens": 32
}'
```
## Configuration via YAML
Use `--extra_llm_api_options` to supply a YAML file that augments or overrides server/runtime settings.
```bash
trtllm-serve \
meta-llama/Llama-3.1-8B \
--backend _autodeploy \
--extra_llm_api_options autodeploy_config.yaml
```
Example `autodeploy_config.yaml`:
```yaml
# Compilation backend for AutoDeploy
compile_backend: torch-opt # options: torch-simple, torch-compile, torch-cudagraph, torch-opt
# Runtime engine
runtime: trtllm # options: trtllm, demollm
# Model loading
skip_loading_weights: false # set true for architecture-only perf runs
# KV cache memory
free_mem_ratio: 0.8 # fraction of free GPU mem for KV cache
# CUDA graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]
# Attention backend
attn_backend: flashinfer # recommended for best performance
```
## Limitations and tips
- KV cache block reuse is disabled automatically for the AutoDeploy backend
- The AutoDeploy backend does not yet support disaggregated serving (work in progress)
- For best performance:
- Prefer `compile_backend: torch-opt`
- Use `attn_backend: flashinfer`
- Set realistic `cuda_graph_batch_sizes` that match expected traffic
- Tune `free_mem_ratio` to 0.8-0.9
## See also
- [AutoDeploy overview](../auto-deploy.md)
- [Benchmarking with trtllm-bench](./benchmarking_with_trtllm_bench.md)

View File

@ -59,6 +59,7 @@ The exported graph then undergoes a series of automated transformations, includi
- [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md)
- [Expert Configurations](./advanced/expert_configurations.md)
- [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md)
- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md)
## Roadmap

View File

@ -1,3 +1,3 @@
tensorrt_llm==1.1.0rc1
tensorrt_llm==1.1.0rc2
evaluate~=0.4.1
rouge_score~=0.1.2

View File

@ -7,8 +7,8 @@ from difflib import SequenceMatcher
import torch
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
from tensorrt_llm.mapping import CpType
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
def dump_jsonl(data, fname):
@ -54,11 +54,8 @@ def similarity_score(a, b):
return SequenceMatcher(None, a, b).ratio()
# Generate the outputs using either TRT or PyTorch (based on the use_pytorch argument). Its the same function for both workflows.
def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False):
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
kv_cache_quant_algo=QuantAlgo.FP8 if fp8_kv_cache
else None) if fp8 else QuantConfig()
kv_cache_config = KvCacheConfig(dtype="fp8" if fp8_kv_cache else "auto")
cp_config = {
"cp_type": CpType.STAR,
"cp_anchor_size": args.sa_anchor_size,
@ -70,7 +67,7 @@ def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False):
max_input_len=args.max_input_len,
max_seq_len=args.max_seq_len,
max_num_tokens=args.max_num_tokens,
quant_config=quant_config,
kv_cache_config=kv_cache_config,
tensor_parallel_size=1,
context_parallel_size=args.num_procs,
cp_config=cp_config,

View File

@ -57,10 +57,10 @@ def CONFIG_LINUX_AARCH64_CU12 = "linux_aarch64_CU12"
def CONFIG_LINUX_AARCH64_LLVM = "linux_aarch64_LLVM"
@Field
def CONFIG_LINUX_X86_64_NANOBIND = "linux_x86_64_Nanobind"
def CONFIG_LINUX_X86_64_PYBIND = "linux_x86_64_Pybind"
@Field
def CONFIG_LINUX_AARCH64_NANOBIND = "linux_aarch64_Nanobind"
def CONFIG_LINUX_AARCH64_PYBIND = "linux_aarch64_Pybind"
@Field
def BUILD_CONFIGS = [
@ -76,9 +76,9 @@ def BUILD_CONFIGS = [
(TARNAME) : "TensorRT-LLM-CU12.tar.gz",
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real",
],
(CONFIG_LINUX_X86_64_NANOBIND) : [
(WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
(TARNAME) : "nanobind-TensorRT-LLM.tar.gz",
(CONFIG_LINUX_X86_64_PYBIND) : [
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
(TARNAME) : "pybind-TensorRT-LLM.tar.gz",
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real",
],
(CONFIG_LINUX_X86_64_SINGLE_DEVICE) : [
@ -101,9 +101,9 @@ def BUILD_CONFIGS = [
(TARNAME) : "TensorRT-LLM-GH200-CU12.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;120-real",
],
(CONFIG_LINUX_AARCH64_NANOBIND): [
(WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars WARNING_IS_ERROR=ON",
(TARNAME) : "nanobind-TensorRT-LLM-GH200.tar.gz",
(CONFIG_LINUX_AARCH64_PYBIND): [
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON",
(TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz",
(WHEEL_ARCHS): "90-real;100-real;120-real",
],
(CONFIG_LINUX_AARCH64_LLVM) : [
@ -568,8 +568,8 @@ def launchStages(pipeline, cpu_arch, enableFailFast, globalVars)
pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_CU12 : CONFIG_LINUX_X86_64_VANILLA_CU12),
"Build TRT-LLM LLVM": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_LLVM : CONFIG_LINUX_X86_64_LLVM),
"Build TRT-LLM Nanobind": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_NANOBIND : CONFIG_LINUX_X86_64_NANOBIND),
"Build TRT-LLM Pybind": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_PYBIND : CONFIG_LINUX_X86_64_PYBIND),
]
if (cpu_arch == X86_64_TRIPLE) {

View File

@ -74,7 +74,7 @@ def LINUX_AARCH64_CONFIG = "linux_aarch64"
def LINUX_AARCH64_CONFIG_CU12 = "linux_aarch64_CU12"
@Field
def NANOBIND_CONFIG = "Nanobind"
def PYBIND_CONFIG = "Pybind"
@Field
def BUILD_CONFIGS = [
@ -85,7 +85,7 @@ def BUILD_CONFIGS = [
(LLVM_CONFIG) : [(TARNAME) : "llvm-TensorRT-LLM.tar.gz"],
(LINUX_AARCH64_CONFIG) : [(TARNAME) : "TensorRT-LLM-GH200.tar.gz"],
(LINUX_AARCH64_CONFIG_CU12) : [(TARNAME) : "TensorRT-LLM-GH200-CU12.tar.gz"],
(NANOBIND_CONFIG) : [(TARNAME) : "nanobind-TensorRT-LLM.tar.gz"],
(PYBIND_CONFIG) : [(TARNAME) : "pybind-TensorRT-LLM.tar.gz"],
]
// TODO: Move common variables to an unified location
@ -657,8 +657,7 @@ def createKubernetesPodConfig(image, type, arch = "amd64", gpuCount = 1, perfMod
def driverVersion = Constants.DEFAULT_NVIDIA_DRIVER_VERSION
def cpuCount = "${TESTER_CORES}"
// Multi-GPU only supports DGX-H100 and DGX-H200 due to the hardware stability.
if ((type.contains("dgx-h100") || type.contains("dgx-h200")) && hasMultipleGPUs)
if (hasMultipleGPUs)
{
// Not a hard requirement, but based on empirical values.
memorySize = "${gpuCount * 150}" + "Gi"
@ -672,7 +671,7 @@ def createKubernetesPodConfig(image, type, arch = "amd64", gpuCount = 1, perfMod
targetCould = "kubernetes"
// The following GPU types doesn't support dynamic driver flashing.
if (type.contains("dgx-h100") || type.contains("dgx-h200") || type in ["b100-ts2", "gh200", "rtx-5080", "rtx-5090"]) {
if (type.contains("dgx-h100") || type.contains("dgx-h200") || type.contains("rtx-pro-6000") || type in ["b100-ts2", "gh200", "rtx-5080", "rtx-5090"]) {
selectors = """
kubernetes.io/arch: ${arch}
kubernetes.io/os: linux
@ -1281,6 +1280,7 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO
echoNodeAndGpuInfo(pipeline, stageName)
sh "cat ${MODEL_CACHE_DIR}/README"
sh "nvidia-smi -q"
sh "nvidia-smi topo -m"
sh "df -h"
// setup HF_HOME to cache model and datasets
@ -1789,7 +1789,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
"A10-TensorRT-4": ["a10", "l0_a10", 4, 6],
"A10-TensorRT-5": ["a10", "l0_a10", 5, 6],
"A10-TensorRT-6": ["a10", "l0_a10", 6, 6],
"A10-Nanobind": ["a10", "l0_a10_nanobind", 1, 1],
"A10-Pybind": ["a10", "l0_a10_pybind", 1, 1],
"A30-Triton-1": ["a30", "l0_a30", 1, 1],
"A30-PyTorch-1": ["a30", "l0_a30", 1, 2],
"A30-PyTorch-2": ["a30", "l0_a30", 2, 2],
@ -1809,8 +1809,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
"B200_PCIe-PyTorch-1": ["b100-ts2", "l0_b200", 1, 3],
"B200_PCIe-PyTorch-2": ["b100-ts2", "l0_b200", 2, 3],
"B200_PCIe-PyTorch-3": ["b100-ts2", "l0_b200", 3, 3],
"B200_PCIe-TensorRT-1": ["b100-ts2", "l0_b200", 1, 2],
"B200_PCIe-TensorRT-2": ["b100-ts2", "l0_b200", 2, 2],
"RTX5090-PyTorch-1": ["rtx-5090", "l0_gb202", 1, 1],
"RTX5080-TensorRT-1": ["rtx-5080", "l0_gb203", 1, 2],
"RTX5080-TensorRT-2": ["rtx-5080", "l0_gb203", 2, 2],
@ -1850,6 +1848,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
"H100_PCIe-TensorRT-Post-Merge-5": ["h100-cr", "l0_h100", 5, 5],
"B200_PCIe-Triton-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 1],
"B200_PCIe-PyTorch-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 1],
"B200_PCIe-TensorRT-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 2],
"B200_PCIe-TensorRT-Post-Merge-2": ["b100-ts2", "l0_b200", 2, 2],
"H100_PCIe-TensorRT-Perf-1": ["h100-cr", "l0_perf", 1, 1],
"H100_PCIe-PyTorch-Perf-1": ["h100-cr", "l0_perf", 1, 1],
"DGX_H200-8_GPUs-PyTorch-Post-Merge-1": ["dgx-h200-x8", "l0_dgx_h200", 1, 1, 8],
@ -1857,6 +1857,9 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
"DGX_H200-4_GPUs-TensorRT-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 3, 4],
"DGX_H200-4_GPUs-TensorRT-Post-Merge-2": ["dgx-h200-x4", "l0_dgx_h200", 2, 3, 4],
"DGX_H200-4_GPUs-TensorRT-Post-Merge-3": ["dgx-h200-x4", "l0_dgx_h200", 3, 3, 4],
"RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
"RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4],
"RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4],
]
parallelJobs = x86TestConfigs.collectEntries{key, values -> [key, [createKubernetesPodConfig(key.contains("-CU12-") ? LLM_DOCKER_IMAGE_12_9 : LLM_DOCKER_IMAGE, values[0], "amd64", values[4] ?: 1, key.contains("Perf")), {
@ -1867,8 +1870,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
if (key.contains("llvm")) {
config = LLVM_CONFIG
}
if (key.contains("Nanobind")) {
config = NANOBIND_CONFIG
if (key.contains("Pybind")) {
config = PYBIND_CONFIG
}
if (key.contains("-CU12-")) {
config = VANILLA_CONFIG_CU12
@ -1878,7 +1881,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
fullSet = parallelJobs.keySet()
x86SlurmTestConfigs = [
"RTXPro6000-PyTorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
"DGX_B200-4_GPUs-PyTorch-Post-Merge-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4],
]
fullSet += x86SlurmTestConfigs.keySet()
@ -2099,11 +2101,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
checkPipStage = true
}
if (cpu_arch == AARCH64_TRIPLE && values[5] != DLFW_IMAGE) {
checkPipStage = false
echo "Skip pip install sanity check due to https://nvbugs/5453827"
}
if (checkPipStage) {
stage("Run LLMAPI tests") {
pipInstallSanitySpec = createKubernetesPodConfig(values[5], gpu_type, k8s_arch)
@ -2484,7 +2481,7 @@ pipeline {
def testPhase2StageName = env.testPhase2StageName
if (testPhase2StageName) {
def dgxSigns = ["DGX_H100", "DGX_H200", "GB200", "DGX_B200"]
def dgxSigns = ["DGX_H100", "DGX_H200", "GB200", "DGX_B200", "RTXPro6000-4_GPUs"]
singleGpuJobs = parallelJobs.findAll{!dgxSigns.any{sign -> it.key.contains(sign)}}
dgxJobs = parallelJobs.findAll{dgxSigns.any{sign -> it.key.contains(sign)}}
}

View File

@ -1,6 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cu128
-c constraints.txt
accelerate>=0.25.0
accelerate>=1.7.0
build
colored
cuda-python>=12

View File

@ -435,7 +435,7 @@ def main(*,
install: bool = False,
skip_building_wheel: bool = False,
linking_install_binary: bool = False,
binding_type: str = "pybind",
binding_type: str = "nanobind",
benchmarks: bool = False,
micro_benchmarks: bool = False,
nvtx: bool = False,
@ -984,8 +984,8 @@ def add_arguments(parser: ArgumentParser):
)
parser.add_argument("--binding_type",
choices=["pybind", "nanobind"],
default="pybind",
help="Which binding type to build: pybind or nanobind")
default="nanobind",
help="Which binding library to use: pybind or nanobind")
parser.add_argument("--benchmarks",
action="store_true",
help="Build the benchmarks for the C++ runtime")

View File

@ -19,6 +19,11 @@ transforms:
stage: post_export
cleanup_input_constraints:
stage: post_export
############################################################################################
# RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
############################################################################################
match_moe_pattern:
stage: pattern_matcher
match_repeat_kv:
stage: pattern_matcher
match_eager_attention:
@ -27,12 +32,13 @@ transforms:
stage: pattern_matcher
match_attention_layout:
stage: pattern_matcher
match_moe_pattern:
stage: pattern_matcher
match_rope_pattern:
stage: pattern_matcher
match_rope_layout:
stage: pattern_matcher
############################################################################################
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
############################################################################################
eliminate_redundant_transposes:
stage: pattern_matcher
# TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
@ -57,5 +63,44 @@ transforms:
sharding_transform_executor:
stage: sharding
run_shape_prop: true
############################################################################################
# MOVE MODEL AND LOAD WEIGHTS
############################################################################################
load_weights:
stage: weight_load
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
# fuse_moe:
# stage: post_load_fusion
# fuse_gemms:
# stage: post_load_fusion
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
fuse_collectives:
stage: post_load_fusion
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
# check if we can fuse rmsnorm
fuse_rmsnorm:
stage: post_load_fusion
backend: flashinfer
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
update_in_out_nodes:
stage: cache_init
insert_cached_attention:
stage: cache_init
insert_cached_mla_attention:
stage: cache_init
attn_backend: MultiHeadLatentAttention
initialize_cache:
stage: cache_init
resize_kv_cache:
stage: cache_init
############################################################################################
# COMPILE MODEL
############################################################################################
compile_model:
stage: compile

View File

@ -198,7 +198,6 @@ def prepare_flashinfer_metadata(
flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size),
position_ids.numel(),
)
# return metadata
return (
qo_indptr,

View File

@ -274,6 +274,16 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
self._quant_config = value
### VALIDATION #################################################################################
@field_validator("max_seq_len", mode="before")
@classmethod
def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any:
if value is None:
# Fallback to the AutoDeployConfig default when not provided
return AutoDeployConfig.model_fields["max_seq_len"].get_default(
call_default_factory=True
)
return value
@field_validator("build_config", mode="before")
@classmethod
def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:

View File

@ -54,6 +54,7 @@ class SharedConfig(BaseModel):
sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
local_rank: int = Field(default=0)
world_size: int = Field(default=1)
attn_backend: str = Field(default="flashinfer", description="The attention backend to use.")
class TransformConfig(BaseModel):
@ -285,7 +286,10 @@ class BaseTransform(ABC):
# update + store new meta data
history[t_name] = info
autodeploy_meta[self._history_key] = history
self._set_autodeploy_meta(gm, autodeploy_meta)
if isinstance(gm, GraphModule):
# After compilation, gm becomes type CapturedGraph with no meta data.
self._set_autodeploy_meta(gm, autodeploy_meta)
# return the graph module
return gm

View File

@ -0,0 +1,204 @@
import operator
from typing import Tuple
import torch
from torch.fx import GraphModule
from ...distributed.trtllm import is_trtllm_op_available
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
# TODO: This is an overly simplified model that works well for vanilla Llama models.
# However, we eventually want to consider more sophisticated patterns such as
# * all_reduce(lin1(x) + lin2(x))
# * version above with fused GEMMs (i.e. with a split node)
# * all_reduce(pointwise_op(linear(x)))
# * ...
@TransformRegistry.register("fuse_collectives")
class FuseCollectives(BaseTransform):
"""
Fuses all_reduce ops with preceding (quantized) linear ops into a single fused node for improved performance.
"""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
num_gemm_collective_fusions = 0
# lookup for fused ops
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
lookup = {
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
}
# go through all nodes and find all_reduce nodes
for node in gm.graph.nodes:
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
continue
# check if args are as expected
assert len(node.args) == 1 and not len(node.kwargs), (
"Unexpected args/kwargs for all_reduce"
)
# retrieve parent and check a few conditions on the parent node
parent_node = node.args[0]
if not is_op(parent_node, lookup.keys()):
continue
if len(parent_node.users) > 1:
continue
with gm.graph.inserting_before(node):
# insert fused node
fused_linear_collective_node = gm.graph.call_function(
lookup[get_op_overload_packet(parent_node.target)],
args=parent_node.args,
kwargs=parent_node.kwargs,
)
node.replace_all_uses_with(fused_linear_collective_node)
gm.graph.erase_node(node)
gm.graph.erase_node(parent_node)
num_gemm_collective_fusions += 1
info = TransformInfo(
skipped=False,
num_matches=num_gemm_collective_fusions,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
@TransformRegistry.register("fuse_allreduce_residual_rmsnorm")
class FuseAllreduceResidualRMSNorm(BaseTransform):
"""Essentially, this transformation fuses the following operators into one allreduce trtllm implementation.
* target pattern:
x = all_reduce(x)
y = x + residual
return rmsnorm(y), y
* replacement:
fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)
"""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if not is_trtllm_op_available():
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
num_ar_r_rms_fusions = 0
def trace_and_fuse(allreduce_node, graph):
# Check if all_reduce is followed by addition
users = list(allreduce_node.users.keys())
if len(users) != 1:
return # Skip if all_reduce has more than one consumer
add_node = users[0]
# Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer
# the Huggingface LlamaRMSNorm implementation as example for more details
to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
# operand of pow and mul
pow_node = get_user_if_pattern_match(
to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
)
mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
rsqrt_node = get_user_if_pattern_match(
add_eps_node, [torch.ops.aten.add, operator.add], 1
)
mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
mul_node_2 = get_user_if_pattern_match(
to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
)
# check args of ops: pow(2) and mean(-1)
ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent
ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions
# Match found: Replace with fused operation
if (
to_copy_1
and pow_node
and mean_node
and add_eps_node
and rsqrt_node
and mul_node_1
and to_copy_2
and mul_node_2
and ARGS_MATCH
):
# Gather the inputs for the custom operation
tensor = allreduce_node.args[0]
# Identify the residual argument in the add operation
# One of the args in add_node.args is the output of all_reduce
# The same idea also applies to norm_weight
residual = (
add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
)
norm_weight = (
mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
)
eps = add_eps_node.args[1]
# Insert nodes
with graph.inserting_before(allreduce_node):
fused_node = graph.call_function(
torch.ops.dist.fused_allreduce_residual_rmsnorm,
args=(
tensor,
residual,
norm_weight,
eps,
),
)
# Extract outputs from the tuple returned by `fused_node`
final_output_node = gm.graph.create_node(
"call_function",
target=operator.getitem,
args=(fused_node, 0),
)
add_output_node = gm.graph.create_node(
"call_function",
target=operator.getitem,
args=(fused_node, 1),
)
# Replace all uses of rmsnorm_node with final_output_node
mul_node_2.replace_all_uses_with(final_output_node)
# Replace all uses of add_node with add_output_node
add_node.replace_all_uses_with(add_output_node)
nonlocal num_ar_r_rms_fusions
num_ar_r_rms_fusions += 1
# Traverse all nodes
for node in gm.graph.nodes:
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
trace_and_fuse(allreduce_node=node, graph=gm.graph)
info = TransformInfo(
skipped=False, num_matches=num_ar_r_rms_fusions, is_clean=False, has_valid_shapes=False
)
return gm, info

View File

@ -0,0 +1,65 @@
from typing import List, Literal, Optional, Tuple, Type
from pydantic import Field
from torch.fx import GraphModule
from ...compile import compile_and_capture
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
class CompileModelConfig(TransformConfig):
"""Configuration for the compile model transform."""
cuda_graph_batch_sizes: Optional[List[int]] = Field(
default=None, description="The batch sizes to use for CUDA graphs."
)
num_batched_inputs: int = Field(
default=2, description="The number of batched inputs to use for CUDA graphs."
)
compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
Field(description="The backend to use for compiling the model.")
)
@TransformRegistry.register("compile_model")
class CompileModel(BaseTransform):
"""A transform to compile the model."""
config: CompileModelConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return CompileModelConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
cm.info.set_generate_only_batch()
egm_compiled = compile_and_capture(
gm,
self.config.compile_backend,
args=cm.args,
dynamic_shapes=cm.dynamic_shapes,
compiler_kwargs={
"cuda_graph_batch_sizes": self.config.cuda_graph_batch_sizes,
"num_batched_inputs": self.config.num_batched_inputs,
},
)
cm.info.reset()
# store info object about the transform
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
return egm_compiled, info

View File

@ -6,6 +6,8 @@ import torch
import torch.nn as nn
from torch.fx import GraphModule, Node
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import (
@ -14,7 +16,7 @@ from ...utils.node_utils import (
is_linear_op,
)
from ...utils.quantization_utils import QuantizationImpl
from .._graph import canonicalize_graph
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]):
@ -116,30 +118,36 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
gm.delete_all_unused_submodules()
def fuse_gemms(gm: GraphModule) -> None:
ad_logger.info("GEMM fusion")
ad_logger.debug("Before GEMM fusion: " + str(gm))
# sort linear nodes by parent node
linear_nodes = defaultdict(list)
for node in gm.graph.nodes:
# TODO: we don't handle bias for now...
if is_linear_op(node, include_quantization=True) and node.args[2] is None:
linear_nodes[node.args[0]].append(node)
@TransformRegistry.register("fuse_gemms")
class FuseGemms(BaseTransform):
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
# sort linear nodes by parent node
linear_nodes = defaultdict(list)
for node in gm.graph.nodes:
# TODO: we don't handle bias for now...
if is_linear_op(node, include_quantization=True) and node.args[2] is None:
linear_nodes[node.args[0]].append(node)
# fuse linear nodes
idx = -1
with cuda_memory_tracker():
for parent_node, lin_children in linear_nodes.items():
if len(lin_children) < 2:
continue
# linear nodes to fuse
ad_logger.debug(
f"Found linear nodes to fuse: {lin_children} with parent node: {parent_node}"
)
_insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
# fuse linear nodes
idx = -1
num_matches = 0
with cuda_memory_tracker():
for parent_node, lin_children in linear_nodes.items():
if len(lin_children) < 2:
continue
# linear nodes to fuse
_insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
num_matches += 1
# clean up and return
canonicalize_graph(gm)
torch.cuda.empty_cache()
ad_logger.debug("After GEMM fusion: " + str(gm))
torch.cuda.empty_cache()
info = TransformInfo(
skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
)
return gm, info

View File

@ -0,0 +1,299 @@
"""Graph transformation to automatically add kv cache into fused MHA op."""
import operator
from typing import Dict, Optional, Tuple, Type
import torch
from pydantic import Field
from torch.fx import Graph, GraphModule, Node
from ...custom_ops.attention_interface import AttentionRegistry
from ...distributed.common import all_gather_object, get_world_size
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...transformations._graph import add_graph_input
from ...utils.logger import ad_logger
from ...utils.node_utils import get_all_input_output_nodes, is_op
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
@TransformRegistry.register("update_in_out_nodes")
class UpdateInOutNodes(BaseTransform):
"""Modify the graph module by adding new input nodes.
The new input nodes correspond to the extra arguments needed for cached and flattened attention.
Args:
egm: The graph module to analyze and modify.
cm: Cached sequence interface containing extra argument information.
"""
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
# loop through nodes to get input, output, and get_attr nodes
input_nodes, output_nodes = get_all_input_output_nodes(gm.graph)
# we only expect one input node
assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
# Later on, we can revisit how to support returning hidden states.
assert len(output_nodes) == 1, "Expected exactly one output node!"
assert len(output_nodes[0].all_input_nodes) == 1, (
"Expected to only return final tensor output!"
)
# Activate and add extra argument nodes
new_args = cm.info.switch_to_cached_attn_inputs()
for name in new_args:
input_nodes.append(add_graph_input(gm, name))
info = TransformInfo(skipped=False, num_matches=1, is_clean=False, has_valid_shapes=False)
return gm, info
class InsertCachedAttentionConfig(TransformConfig):
"""Configuration for the insert cached attention transform."""
attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.")
@TransformRegistry.register("insert_cached_attention")
class InsertCachedAttention(BaseTransform):
"""
A transform to insert cached attention into the graph module.
If attn_backend is not provided in the transform config, it is taken from the shared config.
"""
config: InsertCachedAttentionConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return InsertCachedAttentionConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
"""Replace uncached source attention node with corresponding cached attn node."""
attn_descriptor = AttentionRegistry.get(self.config.attn_backend)
cache_config = factory.get_cache_config()
# Get all attention nodes and their info objects
source_op = attn_descriptor.get_source_attention_op()
# pick up graph
graph: Graph = gm.graph
# look for relevant source attention nodes
source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)]
if not source_attn_nodes:
# If no nodes eligible for kv-cache insertion are found, return the current graph
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
# Sanity check
if cm.info.is_paged:
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
# retrieve input nodes
input_nodes, _ = get_all_input_output_nodes(gm.graph)
# insert metadata computation and extract each argument as a node
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
with graph.inserting_before(input_nodes[-1].next):
ret_node = graph.call_function(
get_metadata,
args=(
*input_nodes,
cm.info.page_size,
),
)
metadata_nodes = [
graph.call_function(operator.getitem, args=(ret_node, idx))
for idx in range(num_metadata)
]
buffer_in_lookup: Dict[str, Node] = {}
# replace fused attention node with attention node that has kv cache
num_cached_attn_replacements = 0
for idx, attn_node in enumerate(source_attn_nodes):
# pick out GEMMs
qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()]
# setup + store cache initializers and caches as input nodes
cache_in_nodes = []
for k, get_cache in attn_descriptor.get_cache_initializers(
attn_node, cache_config
).items():
k_indexed = f"{k}_{idx}"
cm.add_cache(k_indexed, get_cache)
cache_in_nodes.append(add_graph_input(gm, k_indexed))
# setup + store global buffer initializers and buffers as input nodes
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
buffer_in_nodes = []
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
if k not in buffer_in_lookup:
cm.add_cache(k, get_buffer)
buffer_in_lookup[k] = add_graph_input(gm, k)
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
# retrieve constants for attention_op
constants = attn_descriptor.get_constants(attn_node)
# insert cached attention replacement op
with graph.inserting_before(attn_node):
cached_attn_node = graph.call_function(
attn_descriptor.get_cached_attention_op(),
args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants),
)
attn_node.replace_all_uses_with(cached_attn_node)
graph.erase_node(attn_node)
num_cached_attn_replacements += 1
info = TransformInfo(
skipped=False,
num_matches=num_cached_attn_replacements,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
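
Editor's note: a small sketch (toy module, illustrative only) of the getitem expansion used for metadata_nodes above: a tuple-returning call is split into one node per element via operator.getitem when traced.

```python
import operator
import torch
from torch.fx import symbolic_trace

class Meta(torch.nn.Module):
    def forward(self, x):
        parts = torch.split(x, 2, dim=-1)  # stand-in for the prepare-metadata op
        return parts[0] + parts[1]

gm = symbolic_trace(Meta())
getitems = [
    n for n in gm.graph.nodes
    if n.op == "call_function" and n.target is operator.getitem
]
assert len(getitems) == 2  # one getitem node per tuple element that is used
```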
@TransformRegistry.register("insert_cached_mla_attention")
class InsertCachedMLAAttention(InsertCachedAttention):
"""
A transform to insert cached MLA attention into the graph module.
It behaves the same as InsertCachedAttention but is registered separately so that a different (MLA-specific) attention backend can be configured.
"""
pass
class ResizeKVCacheConfig(TransformConfig):
"""Configuration for the resize kv cache transform."""
free_mem_ratio: float = Field(
description="The fraction of available memory to occupy.", default=0.8
)
@TransformRegistry.register("resize_kv_cache")
class ResizeKVCache(BaseTransform):
"""Inflate the kv cache to occupy the available GPU memory.
free_mem_ratio specifies the fraction of available memory to occupy.
"""
config: ResizeKVCacheConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return ResizeKVCacheConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
free_mem_ratio = self.config.free_mem_ratio
def _get_mem_info_in_mb():
free_mem, total_mem = torch.cuda.mem_get_info()
return free_mem // 1024**2, total_mem // 1024**2
free_mem, total_mem = _get_mem_info_in_mb()
ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
current_cache_size = cm.current_cache_size_bytes()
current_num_pages = cm.info.num_pages
ad_logger.info(
f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
)
if free_mem_ratio == 0.0:
ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}")
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
try:
# Let's run a forward pass to get the memory usage
cm.info._set_max_num_tokens_sample()
free_mem_pre, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
gm(*cm.args)
free_mem_post, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")
memory_for_forward_pass = free_mem_pre - free_mem_post
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
# Need to sync all the GPUs
gathered_num_pages = [None] * get_world_size()
all_gather_object(gathered_num_pages, new_num_pages)
new_num_pages = min(gathered_num_pages)
ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}")
cm.resize_cache(new_num_pages)
except Exception as e:
ad_logger.warning(
f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
)
# Free memory
torch.cuda.empty_cache()
info = TransformInfo(
skipped=False,
num_matches=0,
is_clean=True,
has_valid_shapes=True,
)
return gm, info
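
Editor's note: the page-count arithmetic above, replayed with made-up numbers (free_mem_post is in MB in the transform); this is only an illustration of the formula, not additional logic.

```python
free_mem_post = 40_000            # MB free after the probing forward pass
free_mem_ratio = 0.8              # fraction of free memory to claim
current_cache_size = 2 * 1024**3  # bytes currently held by the kv cache
current_num_pages = 1024

bytes_per_page = current_cache_size // current_num_pages
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // bytes_per_page)
print(new_num_pages)  # existing pages plus pages carved out of free memory
```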
@TransformRegistry.register("initialize_cache")
class InitializeCache(BaseTransform):
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
cm.initialize_caches()
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
return gm, info

View File

@ -0,0 +1,148 @@
"""Graph transform to optimize RMSNorm execution using FlashInfer."""
from functools import partial
from typing import Tuple, Type
import torch
from pydantic import Field
from torch.fx import GraphModule
from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
BaseTransform,
SharedConfig,
TransformConfig,
TransformInfo,
TransformRegistry,
)
_BACKEND_OPS = {
"flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
"triton": torch.ops.auto_deploy.triton_rms_norm,
"torch": torch.ops.auto_deploy.torch_rmsnorm,
}
def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
"""Implements the RMSNorm pattern for pattern matching.
Args:
data: Input tensor to normalize.
weight: Scaling weights for the normalized output.
eps: Small constant for numerical stability.
Returns:
Normalized and scaled tensor.
"""
input_dtype = data.dtype
data = data.to(torch.float32)
variance = data.pow(2).mean(-1, keepdim=True)
data = data * torch.rsqrt(variance + eps)
return weight * data.to(input_dtype)
def _rms_norm_replacement(
data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
) -> torch.Tensor:
"""Backend-specific rms_norm implementation.
Args:
data: Input tensor to normalize.
weight: Scaling weights for the normalized output.
eps: Small constant for numerical stability.
backend: Backend to use for RMSNorm computation ("flashinfer", "triton", or "torch").
Returns:
Normalized and scaled tensor using the specified backend implementation.
"""
assert backend.lower() in _BACKEND_OPS, (
f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
)
return _BACKEND_OPS[backend.lower()](data, weight, eps)
class FuseRMSNormConfig(TransformConfig):
"""Configuration for the RMSNorm fusion transform."""
backend: str = Field(
default="flashinfer",
description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
)
@TransformRegistry.register("fuse_rmsnorm")
class FuseRMSNorm(BaseTransform):
"""Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
This function sets up pattern matching to identify RMSNorm operations in the graph
and replaces them with optimized implementations. It uses dummy tensors to register
the pattern matching rules.
Args:
gm: Input graph module to transform.
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
Returns:
Transformed graph module with optimized RMSNorm operations.
"""
config: FuseRMSNormConfig
@classmethod
def get_config_class(cls) -> Type[TransformConfig]:
return FuseRMSNormConfig
def _apply(
self,
gm: GraphModule,
cm: CachedSequenceInterface,
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
if self.config.backend.lower() not in _BACKEND_OPS:
raise ValueError(
f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
)
graph = gm.graph
patterns = ADPatternMatcherPass()
# Create dummy tensors for pattern matching
bs = 2
hidden_size = 512
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
return [
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
eps,
]
# Define configurations for different data types
configs = [
(torch.bfloat16, torch.bfloat16),
(torch.float16, torch.float16),
(torch.float32, torch.float32),
]
# Register patterns for each configuration
for input_dtype, weight_dtype in configs:
register_ad_pattern(
search_fn=_rms_norm_pattern,
replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
patterns=patterns,
dummy_args=dummy_args(input_dtype, weight_dtype),
op_ignore_types={},
scalar_workaround={"eps": 1e-6},
)
cnt = patterns.apply(graph)
info = TransformInfo(skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=False)
return gm, info
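
Editor's note: a quick reference sketch (dummy shapes, CPU only) of the eager RMSNorm computation that the registered pattern matches, i.e. x * rsqrt(mean(x^2) + eps) scaled by weight in the input dtype.

```python
import torch

def rms_norm_ref(data: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = data.float().pow(2).mean(-1, keepdim=True)
    return weight * (data.float() * torch.rsqrt(variance + eps)).to(data.dtype)

x = torch.randn(2, 512, dtype=torch.float16)
w = torch.randn(512, dtype=torch.float16)
out = rms_norm_ref(x, w)
assert out.shape == x.shape and out.dtype == torch.float16
```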

View File

@ -1,11 +1,5 @@
"""A library of transformation passes."""
from .collectives import *
from .fused_moe import *
from .fusion import *
from .kvcache import *
from .rms_norm import *
try:
from .visualization import visualize_namespace
except ImportError:

View File

@ -1,167 +0,0 @@
import operator
import torch
from torch.fx import GraphModule
from ...distributed.trtllm import is_trtllm_op_available
from ...utils.logger import ad_logger
from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
from .._graph import canonicalize_graph
# TODO: This is an overly simplified model that works well for vanilla Llama models.
# However, we eventually want to consider more sophisticated patterns such as
# * all_reduce(lin1(x) + lin2(x))
# * version above with fused GEMMs (i.e. with a split node)
# * all_reduce(pointwise_op(linear(x)))
# * ...
def fuse_collectives(gm: GraphModule) -> None:
num_gemm_collective_fusions = 0
ad_logger.debug("Before GEMM+Collective fusion: " + str(gm))
# lookup for fused ops
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
lookup = {
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
}
# go through all nodes and find all_reduce nodes
for node in gm.graph.nodes:
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
continue
# check if args are as expected
assert len(node.args) == 1 and not len(node.kwargs), "Unexpected args/kwargs for all_reduce"
# retrieve parent and check a few conditions on the parent node
parent_node = node.args[0]
if not is_op(parent_node, lookup.keys()):
continue
if len(parent_node.users) > 1:
continue
with gm.graph.inserting_before(node):
# insert fused node
fused_linear_collective_node = gm.graph.call_function(
lookup[get_op_overload_packet(parent_node.target)],
args=parent_node.args,
kwargs=parent_node.kwargs,
)
node.replace_all_uses_with(fused_linear_collective_node)
gm.graph.erase_node(node)
gm.graph.erase_node(parent_node)
num_gemm_collective_fusions += 1
canonicalize_graph(gm)
ad_logger.info(f"Found {num_gemm_collective_fusions} GEMM+Collective fusions")
ad_logger.debug("After GEMM+Collective fusion: " + str(gm))
def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> None:
"""Essentially, this function fuses the following operators into one allreduce trtllm implementation.
* target pattern:
x = all_reduce(x)
y = x + residual
return rmsnorm(y), y
* replacement:
fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)
"""
if not is_trtllm_op_available():
return
num_ar_r_rms_fusions = 0
ad_logger.debug("Before allreduce+residual+rmsnorm fusion: " + str(gm))
def trace_and_fuse(allreduce_node, graph):
# Check if all_reduce is followed by addition
users = list(allreduce_node.users.keys())
if len(users) != 1:
return # Skip if all_reduce has more than one consumer
add_node = users[0]
# Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer
# the Huggingface LlamaRMSNorm implementation as example for more details
to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
# operand of pow and mul
pow_node = get_user_if_pattern_match(
to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
)
mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
rsqrt_node = get_user_if_pattern_match(add_eps_node, [torch.ops.aten.add, operator.add], 1)
mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
mul_node_2 = get_user_if_pattern_match(
to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
)
# check args of ops: pow(2) and mean(-1)
ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent
ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions
# Match found: Replace with fused operation
if (
to_copy_1
and pow_node
and mean_node
and add_eps_node
and rsqrt_node
and mul_node_1
and to_copy_2
and mul_node_2
and ARGS_MATCH
):
# Gather the inputs for the custom operation
tensor = allreduce_node.args[0]
# Identify the residual argument in the add operation
# One of the args in add_node.args is the output of all_reduce
# The same idea also applies to norm_weight
residual = add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
norm_weight = (
mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
)
eps = add_eps_node.args[1]
# Insert nodes
with graph.inserting_before(allreduce_node):
fused_node = graph.call_function(
torch.ops.dist.fused_allreduce_residual_rmsnorm,
args=(
tensor,
residual,
norm_weight,
eps,
),
)
# Extract outputs from the tuple returned by `fused_node`
final_output_node = gm.graph.create_node(
"call_function",
target=operator.getitem,
args=(fused_node, 0),
)
add_output_node = gm.graph.create_node(
"call_function",
target=operator.getitem,
args=(fused_node, 1),
)
# Replace all uses of rmsnorm_node with final_output_node
mul_node_2.replace_all_uses_with(final_output_node)
# Replace all uses of add_node with add_output_node
add_node.replace_all_uses_with(add_output_node)
nonlocal num_ar_r_rms_fusions
num_ar_r_rms_fusions += 1
# Traverse all nodes
for node in gm.graph.nodes:
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
trace_and_fuse(allreduce_node=node, graph=gm.graph)
canonicalize_graph(gm)
ad_logger.info(f"Found {num_ar_r_rms_fusions} allreduce+residual+rmsnorm fusions")
ad_logger.debug("After allreduce+residual+rmsnorm fusion: " + str(gm))

View File

@ -1,511 +0,0 @@
from collections import defaultdict
from typing import Optional
import torch
from torch.fx import GraphModule, Node
from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op
from ...utils.quantization_utils import get_scales_and_type_from_node
from .._graph import canonicalize_graph
def match_moe_pattern(gm: GraphModule) -> None:
graph = gm.graph
ad_logger.debug("Before MoE Pattern Matching: " + str(gm))
# Preprocessing: Identify boundary nodes (e.g. residual connections) in the graph.
boundary_nodes = identify_regions_between_residuals(gm)
num_moe_patterns = 0
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
# Step 1: Identify Expert Compute pattern
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) = (
_match_expert_compute_pattern(start_boundary, end_boundary)
)
if not expert_weights:
continue
# TODO: naming convention to verify the order of the weight nodes
# Step 2: Trace upwards to locate normalize_routing_weight and selected_experts:
arg1_list, arg2_list = _extract_index_branches_from_expert_outputs(pattern_output_nodes)
normalized_routing_weights = _find_lowest_common_ancessor(arg1_list)
if not normalized_routing_weights:
continue
common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
if not common_ancessor2:
continue
selected_experts = bfs(
common_ancessor2,
lambda node: is_op(node, torch.ops.aten.one_hot),
attr_next="all_input_nodes",
boundary=start_boundary,
).args[0]
if not selected_experts:
continue
# Step 3: Trace upwards to find input node:
hidden_states = _find_lowest_common_ancessor(pattern_input_nodes)
if not hidden_states:
continue
# Step 4: Find output node with the combine pattern
final_hidden_state_node = _find_final_hidden_state_node(pattern_output_nodes, end_boundary)
if final_hidden_state_node is None:
continue
# Step 5: Insert the MoE op into the graph.
ad_logger.debug(
f"Found MoE Pattern: between boundary {start_boundary} and {end_boundary}.\n"
f"Input hidden states node: {hidden_states}, "
f"selected_experts node: {selected_experts}, "
f"routing_weights node: {normalized_routing_weights}, "
f"expert weights: {expert_weights}, weight type: {weight_type}"
)
with graph.inserting_before(final_hidden_state_node):
w1_list = expert_weights["w1"]
w2_list = expert_weights["w2"]
w3_list = expert_weights["w3"]
if weight_type == "fp8":
fused_moe_node = graph.call_function(
torch.ops.auto_deploy.torch_quant_fp8_moe,
args=(
hidden_states,
selected_experts,
normalized_routing_weights,
w1_list,
w2_list,
w3_list,
expert_scales["w1_input_scale"],
expert_scales["w2_input_scale"],
expert_scales["w3_input_scale"],
expert_scales["w1_weight_scale"],
expert_scales["w2_weight_scale"],
expert_scales["w3_weight_scale"],
),
)
elif weight_type == "fp4":
fused_moe_node = graph.call_function(
torch.ops.auto_deploy.torch_quant_fp4_moe,
args=(
hidden_states,
selected_experts,
normalized_routing_weights,
w1_list,
w2_list,
w3_list,
expert_scales["w1_input_scale"],
expert_scales["w2_input_scale"],
expert_scales["w3_input_scale"],
expert_scales["w1_weight_scale"],
expert_scales["w2_weight_scale"],
expert_scales["w3_weight_scale"],
expert_scales["w1_alpha"],
expert_scales["w2_alpha"],
expert_scales["w3_alpha"],
),
)
else:
fused_moe_node = graph.call_function(
torch.ops.auto_deploy.torch_moe,
args=(
hidden_states,
selected_experts,
normalized_routing_weights,
w1_list,
w2_list,
w3_list,
),
)
final_hidden_state_node.replace_all_uses_with(fused_moe_node)
graph.erase_node(final_hidden_state_node)
while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
gm.graph.eliminate_dead_code()
num_moe_patterns += 1
canonicalize_graph(gm)
ad_logger.info(f"Found {num_moe_patterns} MoE Patterns")
ad_logger.debug("After MoE Pattern Matching: " + str(gm))
def fuse_moe(gm: torch.fx.GraphModule) -> None:
"""
Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with
torch.ops.auto_deploy.trtllm_moe_fused.
"""
ad_logger.debug("Before MoE fusion: " + str(gm))
with cuda_memory_tracker():
fused_key_counter = _insert_fused_moe_ops(gm)
if fused_key_counter:
canonicalize_graph(gm)
ad_logger.info(f"Found {fused_key_counter} MoE fusions")
ad_logger.debug("After MoE fusion: " + str(gm))
def _insert_fused_moe_ops(gm: GraphModule) -> int:
fused_key_counter = 0
graph = gm.graph
for node in list(graph.nodes):
if not is_op(node, torch.ops.auto_deploy.torch_moe):
continue
ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}")
hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = node.args
fused_w3_w1_experts = torch.stack(
[
torch.cat(
[gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2
)
for w1_node, w3_node in zip(w1_list, w3_list)
],
dim=0,
)
fused_w2_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0)
new_key_w3_w1 = f"fused_moe_w3_w1_stacked_{fused_key_counter}"
new_key_w2 = f"fused_moe_w2_stacked_{fused_key_counter}"
fused_key_counter += 1
param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts)
param_w2 = torch.nn.Parameter(fused_w2_experts)
gm.register_parameter(new_key_w3_w1, param_w3_w1)
gm.register_parameter(new_key_w2, param_w2)
with graph.inserting_before(node):
new_node = graph.call_function(
# TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models
torch.ops.auto_deploy.trtllm_moe_fused,
args=(
hidden_states,
selected_experts,
routing_weights,
graph.get_attr(new_key_w3_w1),
graph.get_attr(new_key_w2),
),
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
return fused_key_counter
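
Editor's note: a toy-sized sketch of the weight packing performed above: concatenate w3 and w1 per expert, then stack across experts (shapes are made up for illustration).

```python
import torch

num_experts, hidden, inter = 4, 16, 32
w1 = [torch.randn(inter, hidden) for _ in range(num_experts)]
w2 = [torch.randn(hidden, inter) for _ in range(num_experts)]
w3 = [torch.randn(inter, hidden) for _ in range(num_experts)]

fused_w3_w1 = torch.stack(
    [torch.cat([w3_e, w1_e], dim=-2) for w1_e, w3_e in zip(w1, w3)], dim=0
)
fused_w2 = torch.stack(w2, dim=0)

assert fused_w3_w1.shape == (num_experts, 2 * inter, hidden)
assert fused_w2.shape == (num_experts, hidden, inter)
```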
def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
"""
Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following
each node's primary branch (recursively following the first Node argument).
It first finds the LCA of the first two nodes and then
iteratively computes the LCA of the result with the next node, and so on.
Returns:
The common ancestor Node if found, otherwise None.
"""
if not nodes:
return None
def get_parent(node: Node) -> Optional[Node]:
"""Return the first Node-valued argument for a given node, or None if not found."""
for arg in node.args:
if isinstance(arg, Node):
return arg
return None
def get_depth(node: Node) -> int:
"""
Recursively compute the depth of the node by following its primary branch.
Depth is defined as the number of steps to reach a node with no parent.
"""
parent = get_parent(node)
if parent is None:
return 0
return 1 + get_depth(parent)
def lca_two(a: Node, b: Node) -> Optional[Node]:
"""
Find the lowest common ancestor of two nodes by first equalizing their depth
and then moving upward until a common node is found.
"""
depth_a = get_depth(a)
depth_b = get_depth(b)
# Equalize depths
while depth_a > depth_b:
a = get_parent(a)
depth_a -= 1
while depth_b > depth_a:
b = get_parent(b)
depth_b -= 1
# Walk upward in lockstep
while a is not None and b is not None:
if a is b:
return a
a = get_parent(a)
b = get_parent(b)
return None
# Iteratively compute the LCA across all nodes.
common = nodes[0]
for node in nodes[1:]:
common = lca_two(common, node)
if common is None:
return None
return common
def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]:
"""
Given a linear op node, extract the input tensor node, weight tensor,
any quantization scales (if the op is quantized), and return a weight type.
For a torch.ops.auto_deploy.torch_linear_simple.default op:
- Returns (input_node, weight, None, "simple")
For a torch.ops.auto_deploy.torch_quant_fp8_linear op:
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8")
For a torch.ops.auto_deploy.torch_quant_fp4_linear op:
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4")
"""
input_node = linear_node.args[0]
if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple):
weight = linear_node.args[1]
return input_node, weight, None, ""
elif is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear) or is_op(
linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear
):
weight = linear_node.args[1]
scales, quant_type = get_scales_and_type_from_node(linear_node)
return input_node, weight, scales, quant_type
def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node):
"""
Match the expert compute pattern between the given boundaries.
The expert compute pattern corresponds to:
(F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
For each expert, the function extracts the input node from the w1 branch and
collects the weight parameters from three linear ops (w1, w3, and w2 branches).
This function supports both:
- torch.ops.auto_deploy.torch_linear_simple.default ops, and
- torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales).
- torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales).
Returns:
A tuple:
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type)
- pattern_input_nodes: List of input nodes (x) used for the expert compute.
- pattern_output_nodes: List of final expert output nodes (the linear op with weight w2).
- expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors.
- expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors
(empty if weight_type is "simple").
- weight_type: "fp8" if FP8 ops were used, "simple" otherwise.
"""
pattern_input_nodes, pattern_output_nodes = [], []
expert_weights = defaultdict(list)
expert_scales = defaultdict(list)
weight_type = "simple" # default
nodes = list(start_boundary.graph.nodes)
region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)]
for node in region_nodes:
# Accept both simple and quantized linear ops.
if not is_linear_op(node, include_quantization=True):
continue
final_linear = node
if not final_linear.args or not isinstance(final_linear.args[0], Node):
continue
mul_node = final_linear.args[0]
if not is_op(mul_node, torch.ops.aten.mul) or len(mul_node.args) < 2:
continue
arg_a, arg_b = mul_node.args[:2]
silu_node = (
arg_a
if is_op(arg_a, torch.ops.aten.silu)
else arg_b
if is_op(arg_b, torch.ops.aten.silu)
else None
)
if silu_node is None:
continue
if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)):
continue
linear_w1_node = silu_node.args[0]
# The other branch should be a linear op (w3 branch).
linear_w3_node = arg_b if arg_a is silu_node else arg_a
if not is_linear_op(linear_w3_node, include_quantization=True):
continue
if not (linear_w1_node.args and linear_w3_node.args):
continue
# Extract parameters from each linear op.
input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters(
linear_w1_node
)
_, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node)
_, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear)
if None in (weight_w1, weight_w3, weight_w2):
continue
# Ensure the weight type is consistent across branches.
if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2:
continue
weight_type = wt_type_w1
pattern_input_nodes.append(input_node_w1)
pattern_output_nodes.append(final_linear)
expert_weights["w1"].append(weight_w1)
expert_weights["w3"].append(weight_w3)
expert_weights["w2"].append(weight_w2)
# TODO: sanity check that all experts have same weight type
if weight_type == "fp8":
expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
elif weight_type == "fp4":
expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
expert_scales["w1_alpha"].append(quant_params_w1["alpha"])
expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
expert_scales["w3_alpha"].append(quant_params_w3["alpha"])
expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
expert_scales["w2_alpha"].append(quant_params_w2["alpha"])
return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type
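
Editor's note: the eager form of the expert compute pattern the matcher above looks for, written out with toy shapes for illustration.

```python
import torch
import torch.nn.functional as F

x = torch.randn(3, 16)
w1, w3 = torch.randn(32, 16), torch.randn(32, 16)
w2 = torch.randn(16, 32)

# (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
expert_out = (F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
assert expert_out.shape == (3, 16)
```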
def _find_final_hidden_state_node(
pattern_output_nodes: list[Node], end_boundary: Node
) -> Optional[Node]:
"""
Identify the final hidden state node corresponding to the combine pattern:
(expert_output * routing_weight) index_add_
For each expert output node (from the expert compute pattern), this function:
1. Retrieves a multiplication node from its users.
2. Extracts the second argument from the multiplication node (assumed to be the index node).
3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary).
After collecting all such index_add_ nodes, the final hidden state node is determined
as the one that is not used by any of the other index_add_ nodes.
If any required attribute (users or args) is missing during the process or if no valid
final node is found, the function returns None.
"""
if not pattern_output_nodes:
return None
index_add_nodes = []
for node in pattern_output_nodes:
if not node.users:
return None
mul_node = next(iter(node.users))
if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
return None
index_node = mul_node.args[1]
index_add_node = bfs(
index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
)
if not index_add_node:
return None
index_add_nodes.append(index_add_node)
# The final node is defined as the index_add_node that is not used by any other index_add_nodes
return next(
(
candidate
for candidate in index_add_nodes
if not any(
candidate in other.args for other in index_add_nodes if candidate is not other
)
),
None,
)
def _extract_index_branches_from_expert_outputs(
pattern_output_nodes: list[Node],
) -> tuple[list[Node], list[Node]]:
"""
Extract routing and experts branches from expert outputs.
For each expert output, find its multiplication user. From the
multiplication node's second argument (an index node),
extract:
- The first argument as the routing branch.
- The second argument (flattened if a list/tuple) as the experts branch.
Returns:
A tuple (routing_branches, experts_branches).
"""
routing_branches, experts_branches = [], []
for out in pattern_output_nodes:
mul = next((u for u in out.users if is_op(u, torch.ops.aten.mul)), None)
if not mul or len(mul.args) < 2:
continue
idx_node = mul.args[1]
if not is_op(idx_node, torch.ops.aten.index):
continue
routing_branches.append(idx_node.args[0])
experts = idx_node.args[1]
if isinstance(experts, (list, tuple)):
experts_branches.extend(experts)
else:
experts_branches.append(experts)
return routing_branches, experts_branches
def _remove_dead_inplace_nodes_in_region(
graph: torch.fx.Graph,
start_boundary: torch.fx.Node,
end_boundary: torch.fx.Node,
) -> bool:
"""
Searches (via BFS) for a dead in-place node (index_add_) in the region
between start_boundary and end_boundary. If one is found, it is removed from the graph.
Returns True if a node was removed, False otherwise.
"""
def target(n: torch.fx.Node) -> bool:
return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0
try:
node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}")
graph.erase_node(node_to_remove)
return True
except RuntimeError:
return False

View File

@ -1,193 +0,0 @@
"""Graph transformation to automatically add kv cache into fused MHA op."""
import operator
from typing import Dict, Type
import torch
from torch.fx import Graph, GraphModule, Node
from ...custom_ops.attention_interface import AttentionDescriptor, CacheConfig
from ...distributed.common import all_gather_object, get_world_size
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.node_utils import get_all_input_output_nodes, is_op
from .._graph import add_graph_input, canonicalize_graph
def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None:
"""Modify the graph module by adding new input nodes and canonicalizing the graph.
The new input nodes correspond to the extra arguments needed for cached and flattened attention.
Args:
egm: The graph module to analyze and modify.
cm: Cached sequence interface containing extra argument information.
"""
# loop through nodes to get input, output, and get_attr nodes
input_nodes, output_nodes = get_all_input_output_nodes(egm.graph)
# we only expect one input node
assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
# Later on, we can revisit how to support returning hidden states.
assert len(output_nodes) == 1, "Expected exactly one output node!"
assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!"
ad_logger.info(f"Found {len(input_nodes)} input nodes and {len(output_nodes)} output nodes")
# Activate and add extra argument nodes
new_args = cm.info.switch_to_cached_attn_inputs()
for name in new_args:
input_nodes.append(add_graph_input(egm, name))
ad_logger.info(f"Added {len(new_args)} new input nodes for cached attention metadata")
canonicalize_graph(egm)
def insert_cached_attention(
egm: GraphModule,
cm: CachedSequenceInterface,
attn_descriptor: Type[AttentionDescriptor],
cache_config: CacheConfig,
) -> None:
"""Replace uncached source attention node with corresponding cached attn node."""
# Get all attention nodes and their info objects
source_op = attn_descriptor.get_source_attention_op()
# pick up graph
graph: Graph = egm.graph
# look for relevant source attention nodes
source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)]
if not source_attn_nodes:
# If there are no nodes for kv cache insertion found, return current graph
return
# Sanity check
if cm.info.is_paged:
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
ad_logger.debug(f"Before inserting {attn_descriptor=} with cache: {egm}")
# retrieve input nodes
input_nodes, _ = get_all_input_output_nodes(egm.graph)
# insert metadata computation and extract each argument as a node
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
with graph.inserting_before(input_nodes[-1].next):
ret_node = graph.call_function(
get_metadata,
args=(
*input_nodes,
cm.info.page_size,
),
)
metadata_nodes = [
graph.call_function(operator.getitem, args=(ret_node, idx))
for idx in range(num_metadata)
]
buffer_in_lookup: Dict[str, Node] = {}
# replace fused attention node with attention node that has kv cache
num_cached_attn_replacements = 0
for idx, attn_node in enumerate(source_attn_nodes):
# pick out GEMMs
qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()]
# setup + store cache initializers and caches as input nodes
cache_in_nodes = []
for k, get_cache in attn_descriptor.get_cache_initializers(attn_node, cache_config).items():
k_indexed = f"{k}_{idx}"
cm.add_cache(k_indexed, get_cache)
cache_in_nodes.append(add_graph_input(egm, k_indexed))
# setup + store global buffer initializers and buffers as input nodes
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
buffer_in_nodes = []
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
if k not in buffer_in_lookup:
cm.add_cache(k, get_buffer)
buffer_in_lookup[k] = add_graph_input(egm, k)
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
# retrieve constants for attention_op
constants = attn_descriptor.get_constants(attn_node)
# insert cached attention replacement op
with graph.inserting_before(attn_node):
cached_attn_node = graph.call_function(
attn_descriptor.get_cached_attention_op(),
args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants),
)
attn_node.replace_all_uses_with(cached_attn_node)
graph.erase_node(attn_node)
num_cached_attn_replacements += 1
canonicalize_graph(egm)
ad_logger.info(
f"Replaced {num_cached_attn_replacements} {source_op} ops "
f"with {attn_descriptor.get_cached_attention_op()}"
)
ad_logger.debug(f"After inserting {attn_descriptor=} with cache: {egm}")
def resize_kv_cache(
egm: GraphModule,
cm: CachedSequenceInterface,
free_mem_ratio: float = 0.8,
) -> None:
"""Inflate the kv cache to occupy the available GPU memory.
free_mem_ratio specifies the fraction of available memory to occupy.
"""
def _get_mem_info_in_mb():
free_mem, total_mem = torch.cuda.mem_get_info()
return free_mem // 1024**2, total_mem // 1024**2
free_mem, total_mem = _get_mem_info_in_mb()
ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
current_cache_size = cm.current_cache_size_bytes()
current_num_pages = cm.info.num_pages
ad_logger.info(
f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
)
if free_mem_ratio == 0.0:
ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}")
return
try:
# Let's run a forward pass to get the memory usage
cm.info._set_max_num_tokens_sample()
free_mem_pre, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
egm(*cm.args)
free_mem_post, _ = _get_mem_info_in_mb()
ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")
memory_for_forward_pass = free_mem_pre - free_mem_post
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
# Need to sync all the GPUs
gathered_num_pages = [None] * get_world_size()
all_gather_object(gathered_num_pages, new_num_pages)
new_num_pages = min(gathered_num_pages)
ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}")
cm.resize_cache(new_num_pages)
except Exception as e:
ad_logger.warning(
f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
)
# Free memory
torch.cuda.empty_cache()

View File

@ -1,113 +0,0 @@
"""Graph transform to optimize RMSNorm execution using FlashInfer."""
from functools import partial
import torch
from torch.fx import GraphModule
from ...utils.logger import ad_logger
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from .._graph import canonicalize_graph
_BACKEND_OPS = {
"flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
"triton": torch.ops.auto_deploy.triton_rms_norm,
"torch": torch.ops.auto_deploy.torch_rmsnorm,
}
def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
"""Implements the RMSNorm pattern for pattern matching.
Args:
data: Input tensor to normalize.
weight: Scaling weights for the normalized output.
eps: Small constant for numerical stability.
Returns:
Normalized and scaled tensor.
"""
input_dtype = data.dtype
data = data.to(torch.float32)
variance = data.pow(2).mean(-1, keepdim=True)
data = data * torch.rsqrt(variance + eps)
return weight * data.to(input_dtype)
def _rms_norm_replacement(
data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
) -> torch.Tensor:
"""Backend-specific rms_norm implementation.
Args:
data: Input tensor to normalize.
weight: Scaling weights for the normalized output.
eps: Small constant for numerical stability.
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
Returns:
Normalized and scaled tensor using the specified backend implementation.
"""
assert backend.lower() in _BACKEND_OPS, (
f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
)
return _BACKEND_OPS[backend.lower()](data, weight, eps)
def fuse_rmsnorm(gm: GraphModule, backend: str = "triton") -> None:
"""Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
This function sets up pattern matching to identify RMSNorm operations in the graph
and replaces them with optimized implementations. It uses dummy tensors to register
the pattern matching rules.
Args:
gm: Input graph module to transform.
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
Returns:
Transformed graph module with optimized RMSNorm operations.
"""
if backend.lower() not in _BACKEND_OPS:
raise ValueError(f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {backend}")
ad_logger.info(f"Starting RMSNorm pattern matching with backend: {backend}")
graph = gm.graph
patterns = ADPatternMatcherPass()
# Create dummy tensors for pattern matching
bs = 2
hidden_size = 512
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
return [
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
eps,
]
# Define configurations for different data types
configs = [
(torch.bfloat16, torch.bfloat16),
(torch.float16, torch.float16),
(torch.float32, torch.float32),
]
# Register patterns for each configuration
for input_dtype, weight_dtype in configs:
register_ad_pattern(
search_fn=_rms_norm_pattern,
replace_fn=partial(_rms_norm_replacement, backend=backend),
patterns=patterns,
dummy_args=dummy_args(input_dtype, weight_dtype),
op_ignore_types={},
scalar_workaround={"eps": 1e-6},
)
cnt = patterns.apply(graph)
ad_logger.info(f"RMSNorm pattern count: {cnt}")
canonicalize_graph(gm)
ad_logger.debug("RMSNorm pattern matching completed.")

View File

@ -5,21 +5,11 @@ import gc
import torch
import torch.nn as nn
from ..compile import compile_and_capture
from ..custom_ops.attention_interface import AttentionRegistry
from ..llm_args import AutoDeployConfig
from ..models.factory import ModelFactory
from ..shim.interface import CachedSequenceInterface
from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer
from ..utils.logger import ad_logger
from .library import (
fuse_allreduce_residual_rmsnorm,
fuse_collectives,
fuse_rmsnorm,
insert_cached_attention,
resize_kv_cache,
update_in_out_nodes,
)
class InferenceOptimizer:
@ -55,88 +45,60 @@ class InferenceOptimizer:
self.ad_config.attn_backend
).get_attention_layout()
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
# TODO (hg): similar to above.
if "load_weights" in new_optimizer.config:
new_optimizer.config[
if "load_weights" in self.ad_config.transforms:
self.ad_config.transforms[
"load_weights"
].checkpoint_device = self.ad_config.checkpoint_device
new_optimizer.config["load_weights"].device = cm.device
self.ad_config.transforms["load_weights"].device = cm.device
if "resize_kv_cache" in self.ad_config.transforms:
self.ad_config.transforms[
"resize_kv_cache"
].free_mem_ratio = self.ad_config.free_mem_ratio
if "insert_cached_attention" in self.ad_config.transforms:
self.ad_config.transforms[
"insert_cached_attention"
].attn_backend = self.ad_config.attn_backend
if "insert_cached_mla_attention" in self.ad_config.transforms:
self.ad_config.transforms[
"insert_cached_mla_attention"
].attn_backend = self.ad_config.mla_backend
# TODO: (hg)Missing MLA here. Figure out how to add MLA since duplicate transforms are not allowed.
# Old code:
# detect attention op and replace with cache-aware op
# for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
# attn_descriptor = AttentionRegistry.get(a_backend)
# insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
if "compile_model" in self.ad_config.transforms:
self.ad_config.transforms[
"compile_model"
].cuda_graph_batch_sizes = self.ad_config.cuda_graph_batch_sizes
self.ad_config.transforms[
"compile_model"
].compile_backend = self.ad_config.compile_backend
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
# TODO: (hg) move this. let match_rope_layout and match_atten_layout use this shared config
new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend
egm = new_optimizer(cm)
# TODO (lucaslie): continue moving legacy transforms to the new optimizer
############################################################################################
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
# NOTE: (hg)Disabled visualization since compiled gm is a CapturedGraph instead of GraphModule.
# We can add a new stage in the optimizer to visualize the intermediate gm.
# if self.ad_config.visualize:
# try:
# from .library import visualize_namespace
# run MoE fusion
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
# fuse_moe(egm)
# run GEMM fusion
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
# fuse_gemms(egm)
# check if we can fuse allreduce, residual and rmsnorm
fuse_allreduce_residual_rmsnorm(egm)
# check if we can fuse collectives
fuse_collectives(egm)
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
# check if we can fuse rmsnorm
fuse_rmsnorm(egm, "flashinfer")
# visualize the final graph
if self.ad_config.visualize:
try:
from .library import visualize_namespace
visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
ad_logger.warning(
"Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
" the graph."
)
except ImportError:
pass
############################################################################################
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
############################################################################################
update_in_out_nodes(egm, cm)
# detect attention op and replace with cache-aware op
for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
attn_descriptor = AttentionRegistry.get(a_backend)
insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
# initialize cache on correct device
cm.initialize_caches()
# resize kv cache to occupy the available GPU memory up to free_mem_ratio
resize_kv_cache(egm, cm, free_mem_ratio=self.ad_config.free_mem_ratio)
############################################################################################
# COMPILE MODEL
############################################################################################
cm.info.set_generate_only_batch()
compiler_kwargs = {
"cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes,
"num_batched_inputs": 2, # TODO (lucaslie): improve once we have a config system...
}
egm_compiled = compile_and_capture(
egm,
self.ad_config.compile_backend,
args=cm.args,
dynamic_shapes=cm.dynamic_shapes,
compiler_kwargs=compiler_kwargs,
)
cm.info.reset()
# visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
# ad_logger.warning(
# "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
# " the graph."
# )
# except ImportError:
# pass
torch.cuda.empty_cache()
gc.collect()
return egm_compiled
return egm

View File

@ -43,11 +43,13 @@ def _patch_unsupported_input_tensor():
"""
original_fn = lowering.unsupported_input_tensor
def patched_fn(t: torch.Tensor, node=None):
def patched_fn(t: torch.Tensor, *args, **kwargs):
"""Bypass meta tensor check."""
if t.is_meta:
return False
return original_fn(t, node)
return original_fn(
t, *args, **kwargs
) # a generic pass-through of the arguments to accommodate torch side change
lowering.unsupported_input_tensor = patched_fn
try:

View File

@ -453,7 +453,8 @@ class AutoTuner:
p.name
for p in inspect.signature(runner.forward).parameters.values()
}
valid_tactics = runner.get_valid_tactics(input_tensors, profile)
valid_tactics = runner.get_valid_tactics(input_tensors, profile,
**kwargs)
if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
runner(
input_tensors,

View File

@ -531,3 +531,11 @@ def _register_fake():
return router_logits.new_empty(
sz, dtype=torch.int32), router_logits.new_empty(sz,
dtype=torch.float32)
@torch.library.register_fake("trtllm::default_moe_routing_op")
def _(router_logits, topk):
num_tokens = router_logits.shape[0]
sz = (num_tokens, topk)
return router_logits.new_empty(
sz, dtype=torch.int32), router_logits.new_empty(sz,
dtype=torch.float32)

View File

@ -81,12 +81,9 @@ class MoERunner(TunableRunner):
use_fused_finalize)
self.fused_moe_runner = MoERunner.runner_dict[instance_key]
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
return range(self.fused_moe_runner.get_tactic_num())
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
return range(self.fused_moe_runner.get_tactic_num(kwargs["gemm_idx"]))
def forward(
self,
@ -318,11 +315,8 @@ class FP8RowwiseGemmRunner(TunableRunner):
self.fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner.runner_dict[
instance_key]
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
return list(range(self.fp8_rowwise_gemm_runner.get_num_configs()))
def forward(
@ -403,11 +397,8 @@ class FP4GemmRunner(TunableRunner):
output_dtype, int(fp4_gemm_type))
self.fp4_gemm_runner = FP4GemmRunner.runner_dict[instance_key]
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
return list(range(self.fp4_gemm_runner.get_num_configs()))
def forward(
@ -518,11 +509,8 @@ class FP8BatchedGemmRunner(TunableRunner):
return out_tensors
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
mat1, mat2, _, _, _ = inputs
@ -735,11 +723,8 @@ class WeightOnlyQuantGemmRunner(TunableRunner):
self.weight_only_quant_gemm_runner = WeightOnlyQuantGemmRunner.runner_dict[
instance_key]
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
return list(range(self.weight_only_quant_gemm_runner.get_num_configs()))
def forward(
@ -813,11 +798,8 @@ class FinegrainedMixedDtypeGemm(TunableRunner):
self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[
instance_key]
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
return list(
range(self._finegrained_mixed_dtype_gemm_runner.get_num_configs()))

View File

@ -122,11 +122,8 @@ class FP4BlockScaleMoERunner(TunableRunner):
self.local_num_experts, self.routed_scaling_factor,
self.routing_method_type, self.do_finalize, tactic)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
args = FP4BlockScaleMoEInputs(*inputs)
@ -409,11 +406,8 @@ class FP8BlockScaleMoERunner(TunableRunner):
self.local_expert_offset, self.local_num_experts,
self.routed_scaling_factor, self.routing_method_type, tactic)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
args = FP8BlockScaleMoEInputs(*inputs)
@ -670,11 +664,8 @@ class MxE4m3MxE2m1BlockScaleMoERunner(TunableRunner):
self.local_expert_offset, self.local_num_experts,
self.routed_scaling_factor, self.routing_method_type, tactic)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
args = MxE4m3MxE2m1BlockScaleMoEInputs(*inputs)
@ -907,11 +898,8 @@ class E4m3MxE2m1BlockScaleMoERunner(TunableRunner):
self.local_expert_offset, self.local_num_experts,
self.routed_scaling_factor, self.routing_method_type, tactic)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
args = E4m3MxE2m1BlockScaleMoEInputs(*inputs)
@ -1123,11 +1111,8 @@ class Bf16MxE2m1BlockScaleMoERunner(TunableRunner):
self.local_num_experts, self.routed_scaling_factor,
self.routing_method_type, tactic)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
def get_valid_tactics(self, inputs: List[torch.Tensor],
profile: OptimizationProfile, **kwargs) -> List[int]:
args = Bf16MxE2m1BlockScaleMoEInputs(*inputs)

View File

@ -65,7 +65,7 @@ from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..peft.lora.layer import LoraLayer
from ..speculative import MTPSpecMetadata, SpecMetadata
from ..speculative import SpecMetadata
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .modeling_speculative import SpecDecOneEngineForCausalLM
from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
@ -230,7 +230,7 @@ class DeepseekV3Attention(MLA):
aux_stream: Optional[torch.cuda.Stream] = None,
):
config = model_config.pretrained_config
predicted_tokens_per_seq = model_config.spec_config.num_nextn_predict_layers + 1 if model_config.spec_config is not None else 1
predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
super().__init__(hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
@ -750,6 +750,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: torch.Tensor,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
@ -765,16 +766,24 @@ class DeepseekV3DecoderLayer(DecoderLayer):
**kwargs,
)
if isinstance(self.mlp, Deepseekv3MoE):
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
self.fusion_config.POST_MOE_FUSION = False
return self.forward_MoE(
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
spec_metadata=spec_metadata,
)
else:
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
self.fusion_config.POST_MLP_FUSION = False
assert isinstance(self.mlp, GatedMLP)
return self.forward_mlp(
hidden_states=hidden_states,
residual=residual,
spec_metadata=spec_metadata,
)
def forward_MoE(
@ -782,6 +791,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: torch.Tensor,
spec_metadata: Optional[SpecMetadata] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
@ -856,6 +866,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states, residual = self.moe_allreduce(
fc2_output, all_reduce_params=moe_all_reduce_params)
else:
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
spec_metadata.maybe_capture_hidden_states(
self.layer_idx, hidden_states, residual)
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
@ -866,6 +880,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
spec_metadata: Optional[SpecMetadata] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.fusion_config.PRE_MLP_FUSION:
@ -903,6 +918,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
),
)
else:
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
spec_metadata.maybe_capture_hidden_states(
self.layer_idx, hidden_states, residual)
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
@ -1105,6 +1124,7 @@ class DeepseekV3Model(DecoderModel):
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
spec_metadata=spec_metadata,
)
return hidden_states
@ -1132,7 +1152,8 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
model_config=model_config)
self.model_nextn = 0
if model_config.spec_config is not None:
if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp(
):
model_nextn = model_config.spec_config.num_nextn_predict_layers
ckpt_nextn = self.config.num_nextn_predict_layers
self.num_hidden_layers = self.config.num_hidden_layers
@ -1167,11 +1188,10 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
input_ids: torch.IntTensor = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[MTPSpecMetadata] = None,
spec_metadata: Optional[SpecMetadata] = None,
return_context_logits: bool = False,
**kwargs,
) -> torch.Tensor:
attn_metadata.num_generations_per_batch = self.model_nextn + 1
return super().forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
@ -1313,7 +1333,9 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
for name, module in tqdm(all_named_modules.items(),
desc="Loading weights"):
if len(module._parameters) > 0:
if len(module._parameters) <= 0 or name.startswith("draft_model"):
continue
else:
names = name.split('.')
parent_module_name = '.'.join(names[:-1])
if "model.layers" in name and int(

View File

@ -194,11 +194,16 @@ class Gemma3VLM(PreTrainedModel):
"text_config", "vision_config"
], f"Expected subconfig name to be either 'text_config' or 'vision_config'. Got {name} instead."
pretrained_config = getattr(model_config.pretrained_config, name)
# ModelOpt currently doesn't quantize the vision part. Without setting quant config to None,
# weight loading fails for vision.
quant_config = model_config.quant_config if name == "text_config" else None
# FlashInfer backend supports custom mask which is needed for bidirectional mask in decoder.
preferred_backend = "FLASHINFER" if name == "text_config" else "TRTLLM"
sub_model_config: ModelConfig[Gemma3Config] = dataclasses.replace(
model_config,
pretrained_config=pretrained_config,
attn_backend=preferred_backend)
attn_backend=preferred_backend,
quant_config=quant_config)
# Make sure some fields that are not explicitly included in the sub config, but present
# in the top-level config, are replicated.
if (hasattr(sub_model_config.pretrained_config, "torch_dtype")
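For illustration, a minimal standalone sketch of the dataclasses.replace pattern used above to derive per-submodel configs; the ModelConfig here is a simplified stand-in, not the real TensorRT-LLM class.

import dataclasses
from typing import Any, Optional

@dataclasses.dataclass
class ModelConfig:  # simplified stand-in for illustration only
    pretrained_config: Any
    quant_config: Optional[str] = "fp8"
    attn_backend: str = "TRTLLM"

base = ModelConfig(pretrained_config={"text_config": "txt", "vision_config": "vis"})

# Vision tower: keep the default backend, drop quantization so weight loading works.
vision_cfg = dataclasses.replace(base,
                                 pretrained_config=base.pretrained_config["vision_config"],
                                 attn_backend="TRTLLM",
                                 quant_config=None)

# Text decoder: FlashInfer backend for custom (bidirectional) mask support.
text_cfg = dataclasses.replace(base,
                               pretrained_config=base.pretrained_config["text_config"],
                               attn_backend="FLASHINFER")

print(vision_cfg.quant_config, text_cfg.attn_backend)  # None FLASHINFER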

View File

@ -221,7 +221,9 @@ class NemotronHModel(DecoderModel):
)
if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
self.mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests)
self.mamba_metadata = Mamba2Metadata(
attn_metadata.max_num_requests,
chunk_size=self.model_config.pretrained_config.chunk_size)
self.mamba_metadata.prepare(attn_metadata)
if inputs_embeds is None:

View File

@ -611,23 +611,21 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
@staticmethod
def lora_config(model_dir: str):
_lora_config = LoraConfig(
lora_dir=[
f"{model_dir}/vision-lora",
f"{model_dir}/speech-lora",
],
lora_target_modules=[
"attn_qkv",
"attn_dense",
"mlp_h_to_4h",
"mlp_gate_up",
"mlp_4h_to_h",
],
trtllm_modules_to_hf_modules={
"attn_qkv": "qkv_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_up_proj",
"mlp_gate_up": "gate_up_proj",
"mlp_4h_to_h": "down_proj",
},
max_lora_rank=320, # Max rank for Phi4MM.
swap_gate_up_proj_lora_b_weight=
False, # Disable swap gate_up_proj.lora_B.weight for Phi4MM.
)
return _lora_config
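A brief usage sketch of the helper above; the import path is an assumption for illustration and the checkpoint path is a placeholder.

# Assumed import path; adjust to wherever Phi4MMForCausalLM lives in the tree.
from tensorrt_llm._torch.models.modeling_phi4mm import Phi4MMForCausalLM

lora_config = Phi4MMForCausalLM.lora_config("/path/to/Phi-4-multimodal-instruct")
print(lora_config.lora_dir)                          # vision-lora and speech-lora subdirs
print(lora_config.max_lora_rank)                     # 320
print(lora_config.swap_gate_up_proj_lora_b_weight)   # False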

View File

@ -155,10 +155,12 @@ class Eagle3DraftModel(DecoderModel):
else:
self.hidden_size_in = config.hidden_size
self.fc = Linear(self.hidden_size_in * 3,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=config.torch_dtype)
if self.spec_config.num_capture_layers > 1:
self.fc = Linear(self.hidden_size_in *
self.spec_config.num_capture_layers,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=config.torch_dtype)
self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)
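The resized projection above reflects that the EAGLE-3 draft input is the concatenation of hidden states captured from num_capture_layers target layers; a toy sketch of the shape arithmetic, with illustrative sizes only:

import torch
import torch.nn as nn

hidden_size, num_capture_layers, num_tokens = 4096, 3, 8
captured = [torch.randn(num_tokens, hidden_size) for _ in range(num_capture_layers)]

# fc projects the concatenated captures back down to the draft model's hidden size.
fc = nn.Linear(hidden_size * num_capture_layers, hidden_size, bias=False)
fused = fc(torch.cat(captured, dim=-1))
print(fused.shape)  # torch.Size([8, 4096])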

View File

@ -183,18 +183,28 @@ class BaseMoeRoutingMethod(nn.Module):
class DefaultMoeRoutingMethod(BaseMoeRoutingMethod):
def __init__(self, top_k: int):
def __init__(self, top_k: int, force_enable_pytorch_op: bool = False):
super().__init__()
self.top_k = top_k
self.force_enable_pytorch_op = force_enable_pytorch_op
def apply(self,
router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
def apply_pytorch(
self, router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
topk_values, topk_indices = torch.topk(torch.nn.functional.softmax(
router_logits.float(), dim=-1),
k=self.top_k,
dim=-1)
return topk_indices.to(torch.int32), topk_values
def apply(self,
router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
num_experts = router_logits.shape[-1]
if self.force_enable_pytorch_op or num_experts > 128 or self.top_k > 8:
return self.apply_pytorch(router_logits)
else:
return torch.ops.trtllm.default_moe_routing_op(
router_logits, self.top_k)
@property
def routing_method_type(self):
return RoutingMethodType.Default
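For reference, a self-contained sketch of the pure-PyTorch path shown above (softmax over experts, then top-k), useful as a baseline when checking the fused default_moe_routing_op kernel; the function name here is illustrative.

import torch

def reference_default_routing(router_logits: torch.Tensor, top_k: int):
    # Mirrors apply_pytorch: softmax over the expert dimension, then top-k.
    probs = torch.nn.functional.softmax(router_logits.float(), dim=-1)
    topk_values, topk_indices = torch.topk(probs, k=top_k, dim=-1)
    return topk_indices.to(torch.int32), topk_values

logits = torch.randn(4, 8)                      # 4 tokens routed over 8 experts
indices, values = reference_default_routing(logits, top_k=2)
print(indices.shape, values.shape)              # torch.Size([4, 2]) torch.Size([4, 2])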

View File

@ -13,15 +13,83 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Tuple
import torch
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
def cu_seqlens_to_chunk_indices_offsets(
cu_seqlens: torch.Tensor,
chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
cu_seqlens (torch.Tensor): 1D tensor of cumulative sequence lengths, shape (num_seqs + 1,). The first element should be 0. Each entry represents the starting index of a sequence in the flattened token array.
chunk_size (int): The size of each physical mamba chunk (number of tokens per chunk).
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- chunk_indices (torch.Tensor): 1D tensor of indices indicating the physical chunk for each logical chunk.
- chunk_offsets (torch.Tensor): 1D tensor of offsets indicating the starting index of each logical chunk within its physical chunk.
This function computes the chunk indices and offsets for the given cu_seqlens and chunk_size.
Both are tensors of integers with length N, where N is the number of logical (pseudo) chunks.
A logical chunk is a sequence of tokens that are all part of the same sequence and are all in the same physical mamba chunk.
In other words, a logical chunk changes every time we cross a sequence boundary or a physical mamba chunk boundary.
Logical chunks are needed to handle batched requests with initial states (see _state_passing_fwd and _chunk_scan_fwd).
The chunk_indices tensor contains the index of the physical chunk for each logical chunk.
The chunk_offsets tensor contains the offset (AKA starting index) of the logical chunk in the physical chunk.
Example:
cu_seqlens = [0, 5, 10]
chunk_size = 8
-> chunk_indices = [0, 0, 1]
-> chunk_offsets = [0, 5, 0]
In this example, we have 2 sequences, each with 5 tokens. The physical chunk size is 8 tokens.
We have three logical chunks:
- the first logical chunk starts at token 0 in the first physical chunk and contains all 5 tokens from the first sequence
- the second logical chunk starts at token 5 in the first physical chunk and contains the first 3 tokens from the second sequence
- the third logical chunk starts at token 0 in the second physical chunk and contains the remaining 2 tokens from the second sequence
"""
total_seqlens = cu_seqlens[-1]
cu_seqlens = cu_seqlens[1:] # remove prepended 0
# the output length grows by one for every sequence boundary that does not
# fall on a multiple of chunk_size
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
> 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=cu_seqlens.device)
chunk_offsets = torch.zeros((N, ),
dtype=torch.int,
device=cu_seqlens.device)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if s is not a multiple of chunk_size, there is one extra chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
# adjust indices and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
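A small usage sketch of the helper above on the docstring's example; the import path is assumed for illustration.

import torch

# Assumed import path; adjust to the module that defines the helper above.
from tensorrt_llm._torch.modules.mamba.mamba2_metadata import (
    cu_seqlens_to_chunk_indices_offsets)

cu_seqlens = torch.tensor([0, 5, 10], dtype=torch.int32)
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(cu_seqlens,
                                                                   chunk_size=8)

# Three logical chunks: (phys 0, off 0), (phys 0, off 5), (phys 1, off 0).
print(chunk_indices.tolist())  # [0, 0, 1]
print(chunk_offsets.tolist())  # [0, 5, 0]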
class Mamba2Metadata:
def __init__(self, max_batch_size: int):
def __init__(self, max_batch_size: int, chunk_size: int):
self.max_batch_size = max_batch_size
self.chunk_size = chunk_size
# cumulative sequence lengths for prefill requests [batch_size+1]
self.cu_seqlens = torch.zeros(max_batch_size + 1,
@ -31,9 +99,18 @@ class Mamba2Metadata:
# sequence index for prefill requests [num_prefill_tokens] - specifies which request each token belongs to
self.seq_idx: torch.Tensor = None
# helper tensors for chunked prefill
self.has_initial_states = torch.zeros(max_batch_size,
dtype=torch.bool,
device="cuda")
self.use_initial_states = False
self.chunk_indices: torch.Tensor = None
self.chunk_offsets: torch.Tensor = None
def prepare(self, attn_metadata: AttentionMetadata):
num_contexts = attn_metadata.num_contexts
context_lens = attn_metadata.seq_lens_cuda[:num_contexts]
num_ctx_tokens = attn_metadata.num_ctx_tokens
if num_contexts > 0:
torch.cumsum(context_lens,
dim=0,
@ -44,4 +121,17 @@ class Mamba2Metadata:
dtype=torch.int,
device=self.cu_seqlens.device),
repeats=context_lens,
output_size=self.cu_seqlens[num_contexts]).unsqueeze(0)
output_size=num_ctx_tokens).unsqueeze(0)
num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq
self.has_initial_states[:num_contexts] = torch.tensor(
num_cached_tokens_per_seq[:num_contexts]) > 0
# precomputed bool to avoid host<->device syncs during forward pass
self.use_initial_states = torch.any(
self.has_initial_states[:num_contexts]).item()
if self.use_initial_states:
self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
self.cu_seqlens[:num_contexts + 1], self.chunk_size)
else:
self.chunk_indices = None
self.chunk_offsets = None
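A minimal standalone sketch of the bookkeeping added above (names and values illustrative): a context request resumes from an initial state iff it already has cached mamba tokens, and the boolean is materialized once on the host so the forward pass avoids device syncs.

import torch

num_cached_tokens_per_seq = [0, 12, 3]      # per context request (illustrative)
num_contexts = len(num_cached_tokens_per_seq)

has_initial_states = torch.tensor(num_cached_tokens_per_seq) > 0
use_initial_states = bool(torch.any(has_initial_states[:num_contexts]))

print(has_initial_states.tolist(), use_initial_states)  # [False, True, True] True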

View File

@ -191,12 +191,15 @@ class Mamba2Mixer(nn.Module):
cu_seqlens = mamba_metadata.cu_seqlens[:num_prefills + 1]
seq_idx = mamba_metadata.seq_idx
has_initial_states = mamba_metadata.has_initial_states[:
num_prefills]
xbc_p = causal_conv1d_fn(xbc_p.transpose(0, 1),
self.conv1d.weight,
self.conv1d.bias,
activation="silu",
conv_states=conv_states,
has_initial_state=has_initial_states,
query_start_loc=cu_seqlens,
cache_indices=state_indices_p).transpose(
0, 1)
@ -216,6 +219,12 @@ class Mamba2Mixer(nn.Module):
"b l (h p) -> b l h p",
h=self.tp_nheads)
initial_states = None
if mamba_metadata.use_initial_states:
initial_states = torch.where(
has_initial_states[:, None, None, None],
ssm_states[state_indices_p], 0)
y, current_ssm_states = mamba_chunk_scan_combined(
x_p,
dt_p,
@ -226,7 +235,9 @@ class Mamba2Mixer(nn.Module):
D=self.D,
z=z_p,
dt_bias=self.dt_bias,
initial_states=None,
initial_states=initial_states,
chunk_indices=mamba_metadata.chunk_indices,
chunk_offsets=mamba_metadata.chunk_offsets,
dt_softplus=self.delta_softplus,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,

View File

@ -314,11 +314,12 @@ def _chunk_scan_fwd_kernel(
# get the cs at the offset boundary
# - c_off == 0 is a passthrough
# - We need dA_cs at the boundary, defined by c_off - no need
# to increase pointer by pid_m (it is a constant offset,
# i.e. the same for all blocks)
dA_cs_m_boundary = tl.load(
dA_cumsum_ptr +
(pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
mask=(((c_off - 1) > -1) and (c_off < chunk_size)),
other=0.0).to(tl.float32)
if HAS_SEQ_IDX:

View File

@ -110,21 +110,24 @@ def _mamba_chunk_scan_combined_fwd(
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx and iii) is_cont_batched to be all specified.
# ii) seq_idx, iii) is_cont_batched, and iv) chunk_offsets to all be specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - We will also make sure that the dA_cumsum is taken only from the start of the
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum[:, :, :, -1],
dA_cumsum,
initial_states=(rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else None),
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=mamba_ssm_cache_dtype or C.dtype,
is_cont_batched=cu_seqlens is not None)
is_cont_batched=cu_seqlens is not None,
chunk_offsets=chunk_offsets)
states, final_states = [
rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states]

View File

@ -41,6 +41,8 @@ def _state_passing_fwd_kernel(
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
dim,
nchunks,
@ -61,6 +63,7 @@ def _state_passing_fwd_kernel(
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_initstates_batch,
stride_initstates_head,
stride_initstates_dim,
@ -76,7 +79,8 @@ def _state_passing_fwd_kernel(
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
final_states_ptr += (pid_b * stride_final_states_batch +
pid_h * stride_final_states_head)
@ -105,35 +109,63 @@ def _state_passing_fwd_kernel(
other=0.0).to(tl.float32)
tl.store(out_ptrs, states, mask=offs_m < dim)
out_ptrs += stride_out_chunk
seq_idx = 0
prev_seq_idx_chunk_end = 0
logical_chunk_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
scale = tl.exp(dA_cs)
scale_mask = True
if HAS_SEQ_IDX:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_new below.
seq_idx_new = tl.load(seq_idx_ptr +
(min((c + 1) * chunk_size, seqlen) - 1) *
stride_seq_idx_seqlen)
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
(c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
if HAS_INITSTATES:
if IS_CONT_BATCHED and seq_idx != seq_idx_new:
if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
# - update state with seq_idx_chunk_end's init state
states = tl.load(initstates_ptrs,
mask=offs_m < dim,
other=0.0).to(tl.float32)
else:
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
seq_idx = seq_idx_new
# - we need to consider the cumsum only of the last sequence in the chunk
# - find its starting position (given by c_off of the logical chunk index)
# - and subtract the cumsum just before that position from the total cumsum
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
# sequence index at the start of the current chunk
seq_idx_chunk_start = tl.load(seq_idx_ptr +
min(c * chunk_size, seqlen) *
stride_seq_idx_seqlen)
logical_chunk_idx += (seq_idx_chunk_end -
seq_idx_chunk_start)
# - load the chunk offset:
c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
mask=logical_chunk_idx < chunk_meta_num,
other=0)
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
if c_off > 0:
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
dA_cs_boundary = tl.load(
dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
(c_off - 1) * stride_dA_cs_csize,
mask=(c_off - 1) > -1 and c_off < chunk_size,
other=0.0)
dA_cs -= dA_cs_boundary
# - increment logical chunk index for every physical chunk
logical_chunk_idx += 1
else:
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
prev_seq_idx_chunk_end = seq_idx_chunk_end
scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
states = scale * states + new_states
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
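A toy numeric illustration of the boundary correction described in the comments above: when the last sequence in a physical chunk starts at offset c_off, the decay used for state passing is the chunk-end cumsum minus the cumsum just before c_off.

import torch

dA = torch.tensor([0.1, 0.2, 0.3, 0.4])    # per-token dA within one physical chunk
dA_cs = torch.cumsum(dA, dim=0)            # [0.1, 0.3, 0.6, 1.0]
c_off = 2                                  # the sequence starts at token 2 of this chunk

decay_log = dA_cs[-1] - dA_cs[c_off - 1]   # only tokens 2 and 3 contribute: 0.3 + 0.4
print(decay_log)                           # tensor(0.7000)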
@ -146,28 +178,36 @@ def _state_passing_fwd_kernel(
def _state_passing_fwd(
states,
dA_chunk_cumsum,
dA_cumsum,
initial_states=None,
seq_idx=None,
chunk_size=None,
out_dtype=None,
is_cont_batched=False,
chunk_offsets=None,
):
batch, nchunks, nheads, dim = states.shape
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
if chunk_size is None:
chunk_size = dA_cumsum.shape[-1]
else:
assert chunk_size == dA_cumsum.shape[-1]
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
if initial_states is not None:
if is_cont_batched:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching, in which case we
# also require seq_idx to be provided
assert seq_idx is not None, ""
assert seq_idx is not None, "seq_idx must be provided for continuous batching"
# - we also need chunk_offsets to be provided, to account
# for computation of dA_cumsum from the start of the
# sequence
assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
else:
# - this is the regular batching case, where initial
# states are used for each example of the batch.
assert initial_states.shape == (batch, nheads, dim)
if seq_idx is not None:
assert chunk_size is not None
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
out_dtype = states.dtype if out_dtype is None else out_dtype
@ -183,13 +223,15 @@ def _state_passing_fwd(
states,
out,
final_states,
dA_chunk_cumsum,
dA_cumsum,
initial_states,
seq_idx,
chunk_offsets,
len(chunk_offsets) if chunk_offsets is not None else 0,
dim,
nchunks,
seqlen if seq_idx is not None else 0,
chunk_size if seq_idx is not None else 0,
chunk_size,
states.stride(0),
states.stride(1),
states.stride(2),
@ -201,9 +243,10 @@ def _state_passing_fwd(
final_states.stride(0),
final_states.stride(1),
final_states.stride(2),
dA_chunk_cumsum.stride(0),
dA_chunk_cumsum.stride(2),
dA_chunk_cumsum.stride(1),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((
initial_states.stride(0),
initial_states.stride(1),

View File

@ -514,7 +514,8 @@ def create_py_executor_instance(
resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
model_engine.set_lora_model_config(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules)
lora_config.trtllm_modules_to_hf_modules,
lora_config.swap_gate_up_proj_lora_b_weight)
max_num_sequences = executor_config.max_batch_size * mapping.pp_size

View File

@ -16,7 +16,6 @@ from tensorrt_llm.mapping import CpType
from ..distributed import Distributed
from .llm_request import (ExecutorRequest, LlmRequest,
executor_request_to_llm_request)
from .sampler import Sampler, TorchSampler
SHUTDOWN_REQUEST_ID = -1
@ -707,21 +706,19 @@ class ExecutorRequestQueue:
def set_exclude_last_generation_logits(self,
disable_overlap_scheduler: bool,
sampler: Sampler) -> None:
pp_size: int) -> None:
# When overlap scheduler is enabled then when starting to handle a new prompt,
# sample_async is called twice before the first call to update_requests:
# - 1st time as a context request that handles the 1st generated token
# - 2nd time as a generation request that handles the 2nd generated token.
# and only after these two calls the sampler's update_request method is called.
# So in a sampler that works by the expected flow of handling the logits in
# sample_async (TorchSampler is an anomaly that instead does that on
# update_requests), every update_request doesn't handle the newest token, but one
# sample_async, every update_request doesn't handle the newest token, but one
# before it. Since all these calls work on the same request object, then its
# logits storage contains the logits of both the token update_requests should work
# on, and also its next token. Thus, excluding the last generation logits from any
# getter is required, when not using TorchSampler.
self.should_exclude_last_generation_logits = not disable_overlap_scheduler and not isinstance(
sampler, TorchSampler)
# getter is required.
self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1
def _should_exclude_last_generation_logits(self) -> bool:
return self.should_exclude_last_generation_logits

Some files were not shown because too many files have changed in this diff.