Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

commit 808059da34
Merge remote-tracking branch 'gitlab/main' into user/xiweny/merge_main_0819
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>

.gitattributes (vendored): 2 changes
@@ -7,3 +7,5 @@
triton_backend/tools/gpt/input_data.json filter=lfs diff=lfs merge=lfs -text
*cubin.cpp filter=lfs diff=lfs merge=lfs -text
docs/source/blogs/media/tech_blog3_mla_absorb.png filter=lfs diff=lfs merge=lfs -text
tests/integration/test_input_files/*.png filter=lfs diff=lfs merge=lfs -text
tests/integration/test_input_files/*.jpg filter=lfs diff=lfs merge=lfs -text
@@ -9,7 +9,7 @@ TensorRT-LLM
[](https://www.python.org/downloads/release/python-31012/)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
[](./tensorrt_llm/version.py)
[](./tensorrt_llm/version.py)
[](./LICENSE)

[Architecture](./docs/source/torch/arch_overview.md) | [Performance](./docs/source/performance/perf-overview.md) | [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html) | [Documentation](./docs/source/) | [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)
@@ -18,10 +18,9 @@ TensorRT-LLM
<div align="left">

## Tech Blogs
* [08/06] Running a High Performance GPT-OSS-120B Inference Server with TensorRT-LLM
* [08/05] Running a High-Performance GPT-OSS-120B Inference Server with TensorRT-LLM
✨ [➡️ link](./docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md)

* [08/01] Scaling Expert Parallelism in TensorRT-LLM (Part 2: Performance Status and Optimization)
✨ [➡️ link](./docs/source/blogs/tech_blog/blog8_Scaling_Expert_Parallelism_in_TensorRT-LLM_part2.md)

@@ -44,6 +43,7 @@ TensorRT-LLM
✨ [➡️ link](./docs/source/blogs/tech_blog/blog1_Pushing_Latency_Boundaries_Optimizing_DeepSeek-R1_Performance_on_NVIDIA_B200_GPUs.md)

## Latest News
* [08/05] 🌟 TensorRT-LLM delivers Day-0 support for OpenAI's latest open-weights models: GPT-OSS-120B [➡️ link](https://huggingface.co/openai/gpt-oss-120b) and GPT-OSS-20B [➡️ link](https://huggingface.co/openai/gpt-oss-20b)
* [07/15] 🌟 TensorRT-LLM delivers Day-0 support for LG AI Research's latest model, EXAONE 4.0 [➡️ link](https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B)
* [06/17] Join NVIDIA and DeepInfra for a developer meetup on June 26 ✨ [➡️ link](https://events.nvidia.com/scaletheunscalablenextgenai)
* [05/22] Blackwell Breaks the 1,000 TPS/User Barrier With Meta’s Llama 4 Maverick
@@ -253,5 +253,5 @@ Deprecation is used to inform developers that some APIs and tools are no longer
## Useful Links
- [Quantized models on Hugging Face](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4): A growing collection of quantized (e.g., FP8, FP4) and optimized LLMs, including [DeepSeek FP4](https://huggingface.co/nvidia/DeepSeek-R1-FP4), ready for fast inference with TensorRT-LLM.
- [NVIDIA Dynamo](https://github.com/ai-dynamo/dynamo): A datacenter scale distributed inference serving framework that works seamlessly with TensorRT-LLM.
- [AutoDeploy](./examples/auto_deploy/README.md): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
- [AutoDeploy](https://nvidia.github.io/TensorRT-LLM/torch/auto_deploy/auto-deploy.html): A prototype backend for TensorRT-LLM to simplify and accelerate the deployment of PyTorch models.
- [WeChat Discussion Group](https://github.com/NVIDIA/TensorRT-LLM/issues/5359): A real-time channel for TensorRT-LLM Q&A and news.
@@ -69,7 +69,7 @@ add_compile_definitions("TLLM_GEN_EXPORT_INTERFACE")
add_compile_definitions("TLLM_ENABLE_CUDA")

set(BINDING_TYPE
    "pybind"
    "nanobind"
    CACHE STRING
    "Binding type of Python bindings for C++ runtime and batch manager")
@@ -24,7 +24,6 @@
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/request.h"
#include "tensorrt_llm/runtime/worldConfig.h"

namespace tensorrt_llm::runtime
@@ -88,37 +87,6 @@ public:
        SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;

private:
    //! @brief Setups decoder internal tensors for new speculative decoding request
    static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
        SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig,
        DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream,
        CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode,
        SizeType32 maxDecodingEngineTokens);

    //! @brief Setups decoder internal tensors for new request in Draft model Sps mode
    static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
        SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream);

    //! @brief Setups decoder internal tensors for new Medusa request
    static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
        DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens);

    //! @brief Setups decoder internal tensors for new Lookahead request
    static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
        DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);

    //! @brief Setups decoder internal tensors for new Explicit draft tokens request
    static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
        DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);

    //! @brief Setups decoder internal tensors for new Eagle request
    static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
        runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);

    [[nodiscard]] std::shared_ptr<runtime::ITensor> retrieveDraftLogits(runtime::ModelConfig const& modelConfig,
        runtime::WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
        runtime::BufferManager const& bufferManager) const;

    bool mSpeculativeDecodingFastLogits;
    bool mIsLeaderInOrchMode;
    bool mIsNormalizeLogProbs;
@@ -1110,7 +1110,7 @@ public:

    [[nodiscard]] SizeType32 getNumDraftTokens() const
    {
        return mDraftTokens->size();
        return hasDraftTokens() ? mDraftTokens->size() : 0;
    }

    void discardDraftTokens(SizeType32 numTokensToDiscard)
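The one-line change above makes getNumDraftTokens() safe to call on requests that never received draft tokens. A minimal standalone sketch of the same guard pattern, with hypothetical names rather than the repository's actual request class:

#include <cstdint>
#include <memory>
#include <vector>

// Sketch only: mDraftTokens may legitimately be null for non-speculative requests.
class RequestSketch
{
public:
    bool hasDraftTokens() const
    {
        return mDraftTokens != nullptr && !mDraftTokens->empty();
    }

    // Before the fix the accessor dereferenced mDraftTokens unconditionally,
    // which is undefined behaviour when the pointer was never set.
    std::int32_t getNumDraftTokens() const
    {
        return hasDraftTokens() ? static_cast<std::int32_t>(mDraftTokens->size()) : 0;
    }

private:
    std::shared_ptr<std::vector<std::int32_t>> mDraftTokens;
};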
@@ -54,20 +54,21 @@ public:

#if defined(_MSC_VER)
    template <typename... Args>
    void log(Level level, char const* format, Args const&... args);
    void log(Level const level, char const* format, Args const&... args);

    template <typename... Args>
    void log(Level level, int rank, char const* format, Args const&... args);
    void log(Level const level, int const rank, char const* format, Args const&... args);
#else
    template <typename... Args>
    void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
    void log(Level const level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));

    template <typename... Args>
    void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
    void log(Level const level, int const rank, char const* format, Args const&... args)
        __attribute__((format(printf, 4, 0)));
#endif

    template <typename... Args>
    void log(Level level, std::string const& format, Args const&... args)
    void log(Level const level, std::string const& format, Args const&... args)
    {
        return log(level, format.c_str(), args...);
    }
@@ -134,7 +135,7 @@ private:
};

template <typename... Args>
void Logger::log(Logger::Level level, char const* format, Args const&... args)
void Logger::log(Logger::Level const level, char const* format, Args const&... args)
{
    if (isEnabled(level))
    {
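The hunk above only adds const to by-value parameters; behaviour is unchanged. For context, a minimal self-contained sketch (hypothetical function, GCC/Clang only) of how a printf-style format attribute lets the compiler check format strings, which is the mechanism the declarations above rely on:

#include <cstdarg>
#include <cstdio>

// Argument 2 is the format string, variadic arguments start at position 3.
void logMessage(int level, char const* format, ...) __attribute__((format(printf, 2, 3)));

void logMessage(int level, char const* format, ...)
{
    std::va_list args;
    va_start(args, format);
    std::fprintf(stderr, "[%d] ", level);
    std::vfprintf(stderr, format, args);
    std::fprintf(stderr, "\n");
    va_end(args);
}

int main()
{
    logMessage(1, "processed %d tokens in %.2f ms", 128, 3.5); // checked like printf
    return 0;
}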
@@ -52,29 +52,30 @@ public:
        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
        : mModelConfig(std::move(modelConfig))
        , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
              worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), worldConfig.getTensorParallelism()}
              worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
              worldConfig.getTensorParallelism()}
        , mDataType{dataType}
        , mAttentionConfig(attentionType, kvFactor)
    {
    }

    CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
        int DPrank = 0, int DPsize = 0)
        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
        : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
        , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
        , mDataType{dataType}
        , mAttentionConfig(attentionType, kvFactor)
    {
    }

    CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
        AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
        int DPrank = 0, int DPsize = 0)
        SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
        nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
        bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
        : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
        , mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
        , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
        , mDataType{dataType}
        , mAttentionConfig(attentionType, kvFactor)
    {
@@ -83,7 +84,7 @@ public:
    [[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept
    {
        return mModelConfig == other.mModelConfig && mParallelConfig == other.mParallelConfig
            && mDataType == other.mDataType;
            && mAttentionConfig == other.mAttentionConfig && mDataType == other.mDataType;
    }

    struct ModelConfig
@@ -103,6 +104,7 @@ public:
    {
        SizeType32 mTensorParallelism;
        SizeType32 mPipelineParallelism;
        SizeType32 mContextParallelism;
        bool mEnableAttentionDP;
        SizeType32 mDPrank;
        SizeType32 mDPsize;
@@ -110,8 +112,8 @@ public:
        [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
        {
            return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
                && mEnableAttentionDP == other.mEnableAttentionDP && mDPrank == other.mDPrank
                && mDPsize == other.mDPsize;
                && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
                && mDPrank == other.mDPrank && mDPsize == other.mDPsize;
        }
    };

@@ -125,6 +127,11 @@ public:
        {
        }

        [[nodiscard]] bool operator==(AttentionConfig const& other) const noexcept
        {
            return mAttentionType == other.mAttentionType && mKvFactor == other.mKvFactor;
        }

        // attentionType ;
        AttentionType mAttentionType;
        int mKvFactor;
@@ -162,6 +169,7 @@ public:
        sstring << "mTokensPerBlock:" << mModelConfig.mTokensPerBlock << "\n";
        sstring << "tp:" << mParallelConfig.mTensorParallelism << "\n";
        sstring << "pp:" << mParallelConfig.mPipelineParallelism << "\n";
        sstring << "cp:" << mParallelConfig.mContextParallelism << "\n";
        sstring << "enableAttentionDP:" << mParallelConfig.mEnableAttentionDP << "\n";
        sstring << "datatype:" << static_cast<int32_t>(mDataType) << "\n";
        sstring << "attentionType:" << static_cast<int32_t>(mAttentionConfig.mAttentionType) << "\n";
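The hunks above thread a new contextParallelism value through the CacheState constructors, the ParallelConfig members, its operator==, and the debug string, and the outer operator== now also compares mAttentionConfig. A minimal sketch of why the extra field has to join the comparison (simplified stand-in struct, DP rank/size omitted; not the repository's class):

#include <cassert>

// Simplified stand-in for kv_cache::CacheState::ParallelConfig.
struct ParallelConfigSketch
{
    int tensorParallelism;
    int pipelineParallelism;
    int contextParallelism; // newly added field
    bool enableAttentionDP;

    bool operator==(ParallelConfigSketch const& other) const noexcept
    {
        // If contextParallelism were omitted here, two KV-cache layouts built for
        // different CP degrees could wrongly be treated as compatible.
        return tensorParallelism == other.tensorParallelism
            && pipelineParallelism == other.pipelineParallelism
            && contextParallelism == other.contextParallelism
            && enableAttentionDP == other.enableAttentionDP;
    }
};

int main()
{
    ParallelConfigSketch a{4, 1, 1, false};
    ParallelConfigSketch b{4, 1, 2, false}; // same TP/PP, different CP
    assert(!(a == b));
    return 0;
}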
@@ -102,11 +102,13 @@ public:
    {
    public:
        TensorPtr draftLogits;
        TensorPtr draftLogitsHost;
        TensorPtr draftProbs;
        TensorPtr targetProbs;
        TensorPtr numDraftTokens;
        TensorPtr numDraftTokensHost;
        TensorPtr draftTokenIds;
        TensorPtr draftTokenIdsHost;
        TensorPtr useDraftLogits;
        TensorPtr useDraftLogitsHost;
@@ -1,54 +0,0 @@
/*
 * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/runtime/iTensor.h"

#include <optional>

namespace tensorrt_llm::runtime::decoder_batch
{

class Request
{
public:
    using TensorConstPtr = ITensor::SharedConstPtr;
    using TensorPtr = ITensor::SharedPtr;
    using BufferPtr = IBuffer::SharedPtr;

    explicit Request(SizeType32 inputLen)
        : inputLen(inputLen)
    {
    }

    //! Mandatory parameters
    SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps

    // optional parameters
    SizeType32 generatedTokensPerEngineStep{1}; //

    //! Optional parameters for speculative decoding
    BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu
    std::optional<TensorPtr> draftLogits; // [generatedTokensPerEngineStep - 1, vocabSize] on gpu
    TensorPtr medusaPaths; // [maxDecodingTokens, maxPathLen], on gpu
    TensorPtr medusaTreeIds; // [maxDecodingTokens], on gpu
    std::optional<executor::LookaheadDecodingConfig> lookaheadRuntimeConfig;
    std::optional<executor::EagleConfig> eagleConfig;
};

} // namespace tensorrt_llm::runtime::decoder_batch
@@ -1012,7 +1012,7 @@ CUBIN_EXPORT __global__
    if (threadIdx.x < smem.gemm1AccColMax.size)
    {
        auto const idx = threadIdx.x;
        smem.gemm1AccColMax[idx] = mha::numeric_limits<float>::lowest();
        smem.gemm1AccColMax[idx] = safeInitRowMax;
        smem.gemm1AccColSum[idx] = 0;
    }
    smem.gemm1WarpGrpBar.arrive_and_wait();
@@ -1949,7 +1949,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
    uint32_t const globalRow = tileStartRow + row;
    if (globalRow >= cacheSeqLen)
    {
        acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
        acc(m, n)(i, j) = safeInitRowMax;
        continue;
    }
    if (globalRow >= maskStartRow)
@@ -1957,7 +1957,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
    uint32_t const maskRow = globalRow - maskStartRow;
    if ((bit_mask >> maskRow) == 0)
    {
        acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
        acc(m, n)(i, j) = safeInitRowMax;
    }
    }
    }
@@ -2087,7 +2087,7 @@ __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32
#pragma unroll
    for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
    {
        acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
        acc(m, n)(i, j) = safeInitRowMax;
    }
    }
    }
@@ -2380,9 +2380,9 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
    {
        uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j;
        assert((col < nbValidCols) == bool(endMask & (1ULL << col)));
        if (((mask >> col) & 1) == 0)
        if ((mask & (1ULL << col)) == 0)
        {
            acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
            acc(m, n)(i, j) = safeInitRowMax;
        }
    }
    }
@@ -2410,7 +2410,7 @@ __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uin
#pragma unroll
    for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++)
    {
        acc(m, n)(i, j) = mha::numeric_limits<float>::lowest();
        acc(m, n)(i, j) = safeInitRowMax;
    }
    }
    }
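These hunks consistently replace mha::numeric_limits<float>::lowest() with safeInitRowMax when initializing or masking attention-score accumulators. A plausible motivation (hedged; safeInitRowMax is defined elsewhere in the kernel and its exact value is not shown here) is that the most negative representable float is only one scaling step away from -inf, and once a fully masked row's running maximum becomes -inf the softmax rescaling term turns into -inf - (-inf) = NaN, whereas a large-but-finite sentinel keeps the arithmetic well defined. A small host-side illustration with an assumed sentinel value:

#include <cstdio>
#include <limits>

int main()
{
    float const lowest = std::numeric_limits<float>::lowest(); // about -3.4e38
    float const safeInitRowMax = -1e30f;                       // assumed finite sentinel, illustration only
    float const scale = 1.5f;                                  // stand-in for a softmax scaling factor

    float const maskedLowest = lowest * scale;       // overflows to -inf
    float const maskedSafe = safeInitRowMax * scale; // stays finite

    // x - rowMax with rowMax == -inf yields NaN; with a finite sentinel it stays 0.
    std::printf("lowest*scale = %g, x - rowMax = %g\n", maskedLowest, maskedLowest - maskedLowest);
    std::printf("safe*scale   = %g, x - rowMax = %g\n", maskedSafe, maskedSafe - maskedSafe);
    return 0;
}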
@@ -833,7 +833,7 @@ public:
    // Runs for 3 iterations or 1 second and picks the best option
    int pickBestTactic(MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile)
    {
        auto tactics = mMoERunner.getTactics();
        auto tactics = mMoERunner.getTactics(static_cast<MoeGemmId>(gemm_to_profile));
        ::nvtx3::scoped_range nvtx(tensorrt_llm::common::nvtx::nextColor(),
            "Tactic Profiling GEMM " + std::to_string(static_cast<int>(gemm_to_profile)));
        // We save space by reusing the same workspace buffer for all tactics when doing full layer profiling. So we
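Throughout this benchmark the tactic list is now queried per GEMM (getTactics(MoeGemmId::GEMM_1) / getTactics(MoeGemmId::GEMM_2)) instead of once for both, since the two grouped GEMMs can expose different valid configurations. A rough sketch of the calling pattern with stand-in types (only the per-GEMM query shape is taken from the diff; everything else is illustrative):

#include <cstdio>
#include <string>
#include <vector>

enum class MoeGemmId { GEMM_1, GEMM_2 };

struct TacticSketch { std::string name; };

struct MoeRunnerSketch
{
    // Assumed shape of the new per-GEMM query: each GEMM exposes its own tactic list.
    std::vector<TacticSketch> getTactics(MoeGemmId id) const
    {
        return id == MoeGemmId::GEMM_1 ? std::vector<TacticSketch>{{"tile128x128"}, {"tile64x128"}}
                                       : std::vector<TacticSketch>{{"tile128x256"}};
    }
};

int main()
{
    MoeRunnerSketch runner;
    auto const tactics1 = runner.getTactics(MoeGemmId::GEMM_1);
    auto const tactics2 = runner.getTactics(MoeGemmId::GEMM_2);
    // Indices are only meaningful within their own GEMM's list, so the two
    // lists must not be mixed when setting tactics.
    std::printf("GEMM_1 has %zu tactics, GEMM_2 has %zu tactics\n", tactics1.size(), tactics2.size());
    return 0;
}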
@ -925,12 +925,14 @@ public:
|
||||
std::pair<int, int> setTactic(
|
||||
int tactic_idx1, int tactic_idx2, MOEParallelismConfig parallelism_config, GemmToProfile gemm_to_profile)
|
||||
{
|
||||
auto tactics = mMoERunner.getTactics();
|
||||
auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1);
|
||||
auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2);
|
||||
std::vector<std::pair<std::reference_wrapper<int>, GemmToProfile>> tactics_to_profile{
|
||||
{tactic_idx1, GemmToProfile::GEMM_1}, {tactic_idx2, GemmToProfile::GEMM_2}};
|
||||
for (auto& combo : tactics_to_profile)
|
||||
{
|
||||
auto& t = combo.first.get();
|
||||
auto& tactics = combo.second == GemmToProfile::GEMM_1 ? tactics1 : tactics2;
|
||||
if (combo.second != gemm_to_profile && gemm_to_profile != GemmToProfile::LAYER)
|
||||
{
|
||||
t = 0; // Unneeded tactic, set to 0
|
||||
@ -947,7 +949,7 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
mMoERunner.setTactic(tactics[tactic_idx1], tactics[tactic_idx2]);
|
||||
mMoERunner.setTactic(tactics1[tactic_idx1], tactics2[tactic_idx2]);
|
||||
mBestTacticGemm1 = tactic_idx1;
|
||||
mBestTacticGemm2 = tactic_idx2;
|
||||
return {tactic_idx1, tactic_idx2};
|
||||
@ -965,7 +967,7 @@ public:
|
||||
auto expert_weights_size
|
||||
= gemm_to_profile == GemmToProfile::GEMM_1 ? mExpertWeight1Size : mExpertWeight2Size;
|
||||
|
||||
auto tactics = mMoERunner.getTactics()[tactic_idx];
|
||||
auto tactics = mMoERunner.getTactics(static_cast<MoeGemmId>(gemm_to_profile))[tactic_idx];
|
||||
if (static_cast<int>(gemm_to_profile) != static_cast<int>(mGemmProfilerBackend.mGemmToProfile))
|
||||
{
|
||||
throw std::runtime_error("Configuration mismatch between mGemmProfilerBackend and runMoEPermute");
|
||||
@ -1074,11 +1076,12 @@ void MixtureOfExpertsBenchmark<TypeTuple_>::runBenchmark(benchmark::State& state
|
||||
}
|
||||
if (LOG_LEVEL >= INFO)
|
||||
{
|
||||
auto tactics = mMoERunner.getTactics();
|
||||
std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics.size() << "\n"
|
||||
<< tactics[tactic_idx1].toString() << std::endl;
|
||||
std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics.size() << "\n"
|
||||
<< tactics[tactic_idx2].toString() << std::endl;
|
||||
auto tactics1 = mMoERunner.getTactics(MoeGemmId::GEMM_1);
|
||||
auto tactics2 = mMoERunner.getTactics(MoeGemmId::GEMM_2);
|
||||
std::cout << "Selected tactic #1: " << tactic_idx1 << "/" << tactics1.size() << "\n"
|
||||
<< tactics1[tactic_idx1].toString() << std::endl;
|
||||
std::cout << "Selected tactic #2: " << tactic_idx2 << "/" << tactics2.size() << "\n"
|
||||
<< tactics2[tactic_idx2].toString() << std::endl;
|
||||
}
|
||||
state.counters["tactic_idx1"] = tactic_idx1;
|
||||
state.counters["tactic_idx2"] = tactic_idx2;
|
||||
|
||||
@ -42,148 +42,15 @@ struct WeightParams
|
||||
->Apply(argGen<MixtureOfExpertsBenchmark<WeightParams<atype, wtype, otype>>>)
|
||||
|
||||
template <class BenchClass>
|
||||
auto listAllTactics()
|
||||
auto listAllTactics(MoeGemmId gemm_id)
|
||||
{
|
||||
int const sm = getSMVersion();
|
||||
using RunnerType = decltype(BenchClass::mMoERunner);
|
||||
return RunnerType::getTactics(sm);
|
||||
return RunnerType::getTactics(sm, gemm_id);
|
||||
}
|
||||
|
||||
template <class BenchClass>
|
||||
int parseTacticToId(nlohmann::json tactic_config)
|
||||
{
|
||||
bool is_tma_warp_specialized = tactic_config.at("is_tma_warp_specialized").get<bool>();
|
||||
int tile_shape_id = -1;
|
||||
std::array<int, 3> tile_shape;
|
||||
if (tactic_config.at("tile_shape").is_array())
|
||||
tactic_config.at("tile_shape").get_to(tile_shape);
|
||||
else
|
||||
tile_shape_id = tactic_config.at("tile_shape").get<int>();
|
||||
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> confs = listAllTactics<BenchClass>();
|
||||
|
||||
try
|
||||
{
|
||||
for (int i = 0; i < confs.size(); i++)
|
||||
{
|
||||
auto const& c = confs[i];
|
||||
if (c.is_tma_warp_specialized != is_tma_warp_specialized)
|
||||
continue;
|
||||
|
||||
if (!is_tma_warp_specialized)
|
||||
{
|
||||
int stages = tactic_config.at("stages").get<int>();
|
||||
if (c.stages != stages)
|
||||
continue;
|
||||
}
|
||||
|
||||
if (tile_shape_id != -1)
|
||||
{
|
||||
int comp = c.getTileConfigAsInt();
|
||||
if (tile_shape_id != comp)
|
||||
continue;
|
||||
if (is_tma_warp_specialized && (int) c.cluster_shape != tactic_config.at("cluster_shape").get<int>())
|
||||
continue;
|
||||
|
||||
// Found matching config
|
||||
return i;
|
||||
}
|
||||
|
||||
// Handle if the user provided a shape instead of the enum value
|
||||
if (is_tma_warp_specialized)
|
||||
{
|
||||
// TODO Add cases for blackwell shapes
|
||||
using Kv = uint64_t;
|
||||
constexpr static auto K = [](int m, int n) { return (uint64_t(m) << 32) | uint64_t(n); };
|
||||
static std::unordered_map<Kv, CutlassTileConfigSM90> const tile_map{
|
||||
{K(64, 16), CutlassTileConfigSM90::CtaShape64x16x128B},
|
||||
{K(64, 32), CutlassTileConfigSM90::CtaShape64x32x128B},
|
||||
{K(64, 64), CutlassTileConfigSM90::CtaShape64x64x128B},
|
||||
{K(64, 128), CutlassTileConfigSM90::CtaShape64x128x128B},
|
||||
{K(64, 256), CutlassTileConfigSM90::CtaShape64x256x128B},
|
||||
|
||||
{K(128, 16), CutlassTileConfigSM90::CtaShape128x16x128B},
|
||||
{K(128, 32), CutlassTileConfigSM90::CtaShape128x32x128B},
|
||||
{K(128, 64), CutlassTileConfigSM90::CtaShape128x64x128B},
|
||||
{K(128, 128), CutlassTileConfigSM90::CtaShape128x128x128B},
|
||||
{K(128, 256), CutlassTileConfigSM90::CtaShape128x256x128B},
|
||||
{K(256, 128), CutlassTileConfigSM90::CtaShape256x128x128B},
|
||||
};
|
||||
|
||||
if (c.getTileConfigAsInt() != (int) tile_map.at(K(tile_shape[0], tile_shape[1])))
|
||||
continue;
|
||||
|
||||
static std::unordered_map<Kv, ClusterShape> const cluster_map{
|
||||
// CTA configs for M=64
|
||||
{K(1, 1), ClusterShape::ClusterShape_1x1x1},
|
||||
{K(2, 1), ClusterShape::ClusterShape_2x1x1},
|
||||
{K(1, 2), ClusterShape::ClusterShape_1x2x1},
|
||||
{K(2, 2), ClusterShape::ClusterShape_2x2x1},
|
||||
};
|
||||
|
||||
std::array<int, 3> cluster_shape;
|
||||
tactic_config.at("cluster_shape").get_to(cluster_shape);
|
||||
|
||||
if (c.cluster_shape != cluster_map.at(K(cluster_shape[0], cluster_shape[1])))
|
||||
continue;
|
||||
|
||||
// Found matching config
|
||||
return i;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::array<int, 3> warp_shape;
|
||||
tactic_config.at("warp_shape").get_to(warp_shape);
|
||||
|
||||
using Kv = uint64_t;
|
||||
constexpr static auto K = [](std::array<int, 3> a, std::array<int, 3> b)
|
||||
{
|
||||
uint64_t sum = 0;
|
||||
for (auto v : a)
|
||||
sum = sum * 512 + v;
|
||||
for (auto v : b)
|
||||
sum = sum * 256 + v;
|
||||
return sum;
|
||||
};
|
||||
static std::unordered_map<Kv, CutlassTileConfig> tile_map{
|
||||
{K({128, 128, 8}, {64, 64, 8}), CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8},
|
||||
|
||||
{K({16, 128, 64}, {16, 32, 64}), CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64},
|
||||
{K({32, 128, 64}, {32, 32, 64}), CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64},
|
||||
|
||||
{K({64, 128, 64}, {32, 64, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64},
|
||||
{K({64, 64, 128}, {32, 64, 64}), CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64},
|
||||
{K({64, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64},
|
||||
|
||||
{K({128, 64, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64},
|
||||
{K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64},
|
||||
{K({128, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64},
|
||||
{K({128, 128, 64}, {64, 32, 64}), CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64},
|
||||
{K({128, 256, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64},
|
||||
|
||||
{K({256, 128, 64}, {64, 64, 64}), CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64},
|
||||
|
||||
{K({16, 256, 64}, {16, 64, 64}), CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64}
|
||||
|
||||
};
|
||||
if (c.tile_config_sm80 != tile_map.at(K(tile_shape, warp_shape)))
|
||||
continue;
|
||||
|
||||
// Found matching config
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (std::out_of_range const& e)
|
||||
{
|
||||
std::cerr << "Warning: error parsing tactic " << tactic_config.dump(2) << std::endl;
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
template <class BenchClass>
|
||||
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
|
||||
void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids, MoeGemmId gemm_id)
|
||||
{
|
||||
if (tactic.is_number_integer())
|
||||
{
|
||||
@ -193,20 +60,16 @@ void parseTacticToVectorID(nlohmann::json& tactic, std::vector<int>& tactic_ids)
|
||||
{
|
||||
for (auto c : tactic)
|
||||
{
|
||||
parseTacticToVectorID<BenchClass>(c, tactic_ids);
|
||||
parseTacticToVectorID<BenchClass>(c, tactic_ids, gemm_id);
|
||||
}
|
||||
}
|
||||
else if (tactic.is_object())
|
||||
{
|
||||
tactic_ids.push_back(parseTacticToId<BenchClass>(tactic));
|
||||
}
|
||||
else if (tactic.is_string())
|
||||
{
|
||||
assert(tactic.is_string());
|
||||
auto tactic_name = tactic.get<std::string>();
|
||||
if (tactic_name == "all")
|
||||
{
|
||||
auto all_tactics = listAllTactics<BenchClass>();
|
||||
auto all_tactics = listAllTactics<BenchClass>(gemm_id);
|
||||
tactic_ids.resize(all_tactics.size());
|
||||
std::iota(tactic_ids.begin(), tactic_ids.end(), 0);
|
||||
}
|
||||
@ -410,39 +273,15 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
|
||||
}
|
||||
|
||||
// Do this after filtering datatypes as tactics only make sense if we know the data type
|
||||
bool has_tactic_ids2 = false;
|
||||
std::vector<int> tactic_ids1{};
|
||||
std::vector<int> tactic_ids2{};
|
||||
if (run_config.contains("tactic_id1") || run_config.contains("tactic_id2"))
|
||||
if (run_config.contains("tactic_id1"))
|
||||
{
|
||||
if (run_config.contains("tactic_id"))
|
||||
{
|
||||
throw std::invalid_argument("Cannot use tactic_id and tactic_idX");
|
||||
}
|
||||
has_tactic_ids2 = true;
|
||||
parseTacticToVectorID<BenchClass>(run_config["tactic_id1"], tactic_ids1);
|
||||
parseTacticToVectorID<BenchClass>(run_config["tactic_id2"], tactic_ids2);
|
||||
parseTacticToVectorID<BenchClass>(run_config["tactic_id1"], tactic_ids1, MoeGemmId::GEMM_1);
|
||||
}
|
||||
else
|
||||
if (run_config.contains("tactic_id2"))
|
||||
{
|
||||
parseTacticToVectorID<BenchClass>(run_config["tactic_id"], tactic_ids1);
|
||||
has_tactic_ids2 = false;
|
||||
tactic_ids2.resize(1); // Dummy value so we loop exactly once below
|
||||
}
|
||||
if (tactic_ids1.empty() || tactic_ids2.empty())
|
||||
{
|
||||
std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
|
||||
static bool printed = false;
|
||||
if (!printed)
|
||||
{
|
||||
printed = true;
|
||||
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
|
||||
auto confs = listAllTactics<BenchClass>();
|
||||
for (auto c : confs)
|
||||
std::cerr << c.toString();
|
||||
}
|
||||
|
||||
continue;
|
||||
parseTacticToVectorID<BenchClass>(run_config["tactic_id2"], tactic_ids2, MoeGemmId::GEMM_2);
|
||||
}
|
||||
|
||||
auto get_or = [&](auto name, auto def)
|
||||
@ -478,8 +317,6 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
|
||||
}
|
||||
else if (gemm_to_profile == (int) GemmToProfile::GEMM_2)
|
||||
{
|
||||
if (!has_tactic_ids2)
|
||||
tactic_ids2 = std::move(tactic_ids1);
|
||||
tactic_ids1 = {-1};
|
||||
}
|
||||
}
|
||||
@ -494,14 +331,31 @@ void argGenLoadFile(benchmark::internal::Benchmark* benchmark)
|
||||
return val;
|
||||
};
|
||||
|
||||
if (tactic_ids1.empty() || tactic_ids2.empty())
|
||||
{
|
||||
std::cerr << "Warning: Skipping benchmark, no valid tactic found" << std::endl;
|
||||
static bool printed = false;
|
||||
if (!printed)
|
||||
{
|
||||
printed = true;
|
||||
std::cerr << __PRETTY_FUNCTION__ << ": Valid Tactics are:\n";
|
||||
for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
|
||||
{
|
||||
std::cerr << "GEMM " << (int) gemm_id << ":\n";
|
||||
auto confs = listAllTactics<BenchClass>(gemm_id);
|
||||
for (auto c : confs)
|
||||
std::cerr << c.toString();
|
||||
std::cerr << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
for (auto t1 : tactic_ids1)
|
||||
{
|
||||
// tactic_ids2 will have one dummy value if has_tactic_ids2 = false
|
||||
for (auto t2 : tactic_ids2)
|
||||
{
|
||||
if (!has_tactic_ids2)
|
||||
t2 = t1;
|
||||
|
||||
benchmark->Args({num_experts, //
|
||||
get_range("k"), //
|
||||
get_range("hidden_size"), //
|
||||
@ -531,7 +385,7 @@ void argGenHardcoded(benchmark::internal::Benchmark* benchmark)
|
||||
// {ActivationType::Relu, ActivationType::Gelu,
|
||||
// ActivationType::Silu, ActivationType::Geglu,
|
||||
// ActivationType::Swiglu};
|
||||
auto cutlass_tactic = {-1}; // {0,..., listAllTactics<BenchClass>().size()};
|
||||
auto cutlass_tactic = {-1}; // {0,..., listAllTactics<BenchClass>(MoeGemmId).size()};
|
||||
auto routing_config = {LOAD_BALANCED_ROUTING_CONFIG}; // {0, 1, 2};
|
||||
|
||||
for (auto num_expert : num_experts)
|
||||
@ -558,14 +412,18 @@ void argGen(benchmark::internal::Benchmark* benchmark)
|
||||
{
|
||||
if (LOG_LEVEL >= VERBOSE)
|
||||
{
|
||||
std::cout << "List of all tactics for dtype " << (int) BenchClass::toDTypeID() << ":\n";
|
||||
int i = 0;
|
||||
for (auto& t : listAllTactics<BenchClass>())
|
||||
std::cout << "== List of all tactics for dtype " << (int) BenchClass::toDTypeID() << " ==\n";
|
||||
for (auto gemm_id : {MoeGemmId::GEMM_1, MoeGemmId::GEMM_2})
|
||||
{
|
||||
std::cout << "Tactic " << i << ":\n";
|
||||
std::cout << t.toString() << std::endl;
|
||||
int i = 0;
|
||||
std::cout << "=== GEMM " << (int) gemm_id << " ===\n";
|
||||
for (auto& t : listAllTactics<BenchClass>(gemm_id))
|
||||
{
|
||||
std::cout << "==== Tactic " << i << " ====\n";
|
||||
std::cout << t.toString() << std::endl;
|
||||
|
||||
i++;
|
||||
i++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -652,7 +510,6 @@ void help()
|
||||
" \"bias\": int, (optional)\n"
|
||||
" \"do_final_scale\": int, (optional)\n"
|
||||
" \"act_fn\": int,\n"
|
||||
" \"tactic_id\": tactic, (see below)\n"
|
||||
" \"tactic_id1\": tactic, (see below)\n"
|
||||
" \"tactic_id2\": tactic, (see below)\n"
|
||||
" \"dtypes\": [string, ...], (optional)\n"
|
||||
@ -676,27 +533,14 @@ void help()
|
||||
"- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n"
|
||||
"- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = "
|
||||
"swiglu\n"
|
||||
"- \"tactic_id, tactic_id1, tactic_id2\"\n"
|
||||
"The config for the CUTLASS GEMM. tactic_id sets the same tactic for both to the same tactic (except in "
|
||||
"auto mode)\n"
|
||||
"Use tactic_idX to set the tactic for the corresponding GEMM"
|
||||
"- \"tactic_id1, tactic_id2\"\n"
|
||||
"The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM"
|
||||
"Valid tactics are:\n"
|
||||
" - An object:\n"
|
||||
" {\n"
|
||||
" \"is_tma_warp_specialized\": bool,\n"
|
||||
" \"tile_shape\": [int, int, int] or int,\n"
|
||||
" \"cluster_shape\": [int, int, int] or int, (required for sm90, type must be an int if tile_shape "
|
||||
"is "
|
||||
"an int)\n"
|
||||
" \"warp_shape\": [int, int, int], (required for non-sm90 if tile_shape is an array)\n"
|
||||
" \"stages\": int, (required for non-sm90)\n"
|
||||
" },\n"
|
||||
" - An integer: corresponds to an index in the tactics array. WARNING this is not stable between test "
|
||||
"configurations\n"
|
||||
" - An array: of integers or objects, forms a list of tactics to sweep\n"
|
||||
" - An integer: corresponds to an index in the tactics array. WARNING this is not stable between data types "
|
||||
"or GPU architectures\n"
|
||||
" - An array: of integers, forms a list of tactics to sweep\n"
|
||||
" - The string \"all\": This will sweep through all possible tactics\n"
|
||||
" - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark "
|
||||
"case. "
|
||||
" - The string \"auto\": This runs a short benchmark to pick the fastest tactic before each benchmark case. "
|
||||
"Useful for quick perf tests, prefer a full sweep and manually setting the tactic for more accurate "
|
||||
"results"
|
||||
"- dtypes - A list of dtypes to run this config through.\n"
|
||||
|
||||
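The help text above now documents only tactic_id1/tactic_id2; the combined tactic_id key was dropped. A hedged sketch of a run-config entry following the documented schema, built with nlohmann::json as the benchmark itself uses (all field values are illustrative, not taken from a real config file):

#include <iostream>
#include <nlohmann/json.hpp>

int main()
{
    // Separate tactics per GEMM: an object-form tactic for GEMM1 and "auto" for GEMM2.
    nlohmann::json run_config = {
        {"num_experts", 8},
        {"k", 2},
        {"hidden_size", 4096},
        {"act_fn", 3}, // 3 = silu per the help text
        {"tactic_id1",
            {{"is_tma_warp_specialized", true}, {"tile_shape", {128, 128, 128}}, {"cluster_shape", {1, 1, 1}}}},
        {"tactic_id2", "auto"},
        {"dtypes", {"fp8"}}};
    std::cout << run_config.dump(2) << std::endl;
    return 0;
}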
@@ -294,8 +294,7 @@ if(TARGET ${NIXL_WRAPPER_TARGET})
endif()

if(NOT WIN32)
  set_target_properties(${SHARED_TARGET} PROPERTIES LINK_FLAGS
                        "-Wl,-rpath='$ORIGIN'")
  set_target_properties(${SHARED_TARGET} PROPERTIES BUILD_RPATH "$ORIGIN")
endif()

if(BUILD_PYT)
@@ -822,6 +822,14 @@ void CacheFormatter::unformat(TransferSession& session)
        TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support non-MLA");
        return false;
    }
    if (selfConfig.getParallelConfig().mContextParallelism != 1
        || destConfig.getParallelConfig().mContextParallelism != 1)
    {
        TLLM_LOG_WARNING(
            "CacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
            selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
        return false;
    }

    std::unordered_set<int> setVecDest{
        destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
@ -20,11 +20,14 @@
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/medusaBuffers.h"
|
||||
#include "tensorrt_llm/batch_manager/utils/logitsThread.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/nvtxUtils.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/decoderState.h"
|
||||
#include "tensorrt_llm/runtime/decodingInput.h"
|
||||
#include "tensorrt_llm/runtime/decodingOutput.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
#include "tensorrt_llm/runtime/speculativeDecodingMode.h"
|
||||
#include "tensorrt_llm/runtime/utils/mpiUtils.h"
|
||||
@ -45,6 +48,8 @@ namespace tensorrt_llm::batch_manager
|
||||
using SizeType32 = CreateNewDecoderRequests::SizeType32;
|
||||
using TensorPtr = CreateNewDecoderRequests::TensorPtr;
|
||||
using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr;
|
||||
template <typename T>
|
||||
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
|
||||
|
||||
namespace
|
||||
{
|
||||
@ -320,149 +325,165 @@ void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeT
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx,
|
||||
runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
|
||||
runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput,
|
||||
CudaStream const& runtimeStream, CudaStream const& decoderStream,
|
||||
SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxDecodingEngineTokens)
|
||||
void retrieveDraftLogits(TensorPtr& draftLogitsHost, std::shared_ptr<runtime::ITensor> const& reqDraftLogits,
|
||||
ModelConfig const& modelConfig, WorldConfig const& worldConfig, bool speculativeDecodingFastLogits,
|
||||
bool isLeaderInOrchMode, BufferManager const& bufferManager)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
if (speculativeDecodingMode.predictsDraftTokens())
|
||||
if (!speculativeDecodingFastLogits)
|
||||
{
|
||||
auto const& stream = decoderStream;
|
||||
BufferManager manager{std::make_shared<CudaStream>(stream.get())};
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
bufferManager.copy(*reqDraftLogits, *draftLogitsHost);
|
||||
return;
|
||||
}
|
||||
|
||||
auto& dJointOutput = jointDecodingOutput;
|
||||
if (isLeaderInOrchMode)
|
||||
{
|
||||
// reqDraftLogits contains metadata for fast-logits path; validate size.
|
||||
auto constexpr fastLogitsInfoSize = sizeof(te::SpeculativeDecodingFastLogitsInfo);
|
||||
TLLM_CHECK_WITH_INFO(reqDraftLogits->getSizeInBytes() >= fastLogitsInfoSize,
|
||||
"Draft logits metadata buffer is too small to hold SpeculativeDecodingFastLogitsInfo.");
|
||||
te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo{};
|
||||
std::memcpy(&fastLogitsInfo, reqDraftLogits->data(), fastLogitsInfoSize);
|
||||
utils::targetModelReceiveLogits(draftLogitsHost, fastLogitsInfo, modelConfig.getLogitsDtype());
|
||||
|
||||
TensorPtr nextDraftTokens
|
||||
= ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokens, batchIdx, 1);
|
||||
// FIXME: can we skip this?
|
||||
manager.setZero(*nextDraftTokens);
|
||||
if (speculativeDecodingMode.variableDraftLength())
|
||||
// Broadcast to other ranks if needed
|
||||
if (worldConfig.isTensorParallel())
|
||||
{
|
||||
TensorPtr nextDraftTokensLen
|
||||
= ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokensLen, batchIdx, 1);
|
||||
manager.setZero(*nextDraftTokensLen);
|
||||
auto const& commSession = COMM_SESSION;
|
||||
auto shape = draftLogitsHost->getShape();
|
||||
commSession.bcastValue(shape.d[0], 0);
|
||||
commSession.bcastValue(shape.d[1], 0);
|
||||
commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(worldConfig.isTensorParallel(),
|
||||
"Fast logits path requires tensor-parallel broadcast for non-leader ranks.");
|
||||
|
||||
if (speculativeDecodingMode.isDraftTokensExternal())
|
||||
{
|
||||
newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream);
|
||||
}
|
||||
else if (speculativeDecodingMode.isMedusa())
|
||||
{
|
||||
newRequestMedusa(batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens);
|
||||
}
|
||||
else if (speculativeDecodingMode.isLookaheadDecoding())
|
||||
{
|
||||
newRequestLookahead(batchIdx, request, jointDecodingInput, jointDecodingOutput, runtimeStream);
|
||||
}
|
||||
else if (speculativeDecodingMode.isExplicitDraftTokens())
|
||||
{
|
||||
newRequestExplicitDraftTokens(batchIdx, request, jointDecodingOutput, runtimeStream);
|
||||
}
|
||||
else if (speculativeDecodingMode.isEagle())
|
||||
{
|
||||
newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream);
|
||||
// Get logits from leader rank
|
||||
auto const& commSession = COMM_SESSION;
|
||||
int64_t dims[2];
|
||||
commSession.bcastValue(dims[0], 0);
|
||||
commSession.bcastValue(dims[1], 0);
|
||||
draftLogitsHost->reshape(ITensor::makeShape({dims[0], dims[1]}));
|
||||
commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
};
|
||||
|
||||
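In the fast-logits path above, the leader rank receives the draft logits and broadcasts first the two shape dimensions and then the raw bytes, while non-leader ranks reshape their host tensor from the broadcast dimensions before receiving the payload. A stripped-down sketch of that shape-then-payload handshake using plain MPI (hypothetical helper, not the repository's COMM_SESSION wrapper):

#include <cstdint>
#include <vector>
#include <mpi.h>

// Broadcast a 2-D row-major float buffer from rank 0: shape first, then payload bytes.
std::vector<float> bcastLogits(std::vector<float> hostLogits, std::int64_t rows, std::int64_t cols, MPI_Comm comm)
{
    int rank = 0;
    MPI_Comm_rank(comm, &rank);

    std::int64_t dims[2] = {rows, cols};
    MPI_Bcast(dims, 2, MPI_INT64_T, 0, comm); // every rank learns the shape first

    if (rank != 0)
    {
        hostLogits.resize(static_cast<std::size_t>(dims[0] * dims[1])); // "reshape" before receiving
    }
    MPI_Bcast(hostLogits.data(), static_cast<int>(hostLogits.size() * sizeof(float)), MPI_BYTE, 0, comm);
    return hostLogits;
}

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    std::vector<float> logits;
    if (rank == 0)
    {
        logits.assign(4 * 32000, 0.5f); // leader already holds the draft logits (toy values)
    }
    logits = bcastLogits(std::move(logits), 4, 32000, MPI_COMM_WORLD);

    MPI_Finalize();
    return 0;
}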
void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx,
|
||||
runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig,
|
||||
DecodingInput& jointDecodingInput, CudaStream const& decoderStream)
|
||||
//! @brief Setups decoder internal tensors for new request in Draft model Sps mode
|
||||
void newRequestDraftTokensExternal(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest const& llmReq,
|
||||
SizeType32 numDecodingEngineTokens, runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, CudaStream const& decoderStream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())};
|
||||
BufferManager decoderBufferManager{std::make_shared<CudaStream>(decoderStream.get())};
|
||||
|
||||
auto& dJointInput = jointDecodingInput;
|
||||
TLLM_CHECK(jointDecodingInput.externalDraftTokensInputs);
|
||||
auto& externalDraftTokensInputs = jointDecodingInput.externalDraftTokensInputs;
|
||||
|
||||
auto const numDraftTokens = request.generatedTokensPerEngineStep - 1;
|
||||
auto const& draftTokens = llmReq.getDraftTokens();
|
||||
auto const numDraftTokens = numDecodingEngineTokens - 1;
|
||||
|
||||
auto const useDraftLogits = request.draftLogits.has_value();
|
||||
if (useDraftLogits)
|
||||
{
|
||||
TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value());
|
||||
|
||||
TensorPtr draftLogitsReqBatchSlice
|
||||
= ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1);
|
||||
draftLogitsReqBatchSlice->squeeze(0);
|
||||
TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens);
|
||||
manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice);
|
||||
}
|
||||
auto* useDraftLogitsHostPtr = runtime::bufferCast<bool>(*dJointInput.externalDraftTokensInputs->useDraftLogitsHost);
|
||||
useDraftLogitsHostPtr[batchIdx] = useDraftLogits;
|
||||
auto useDraftLogitsView = ITensor::slice(dJointInput.externalDraftTokensInputs->useDraftLogits, batchIdx, 1);
|
||||
runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream);
|
||||
auto numDraftTokensHostRange = runtime::BufferRange<SizeType32>(*externalDraftTokensInputs->numDraftTokensHost);
|
||||
numDraftTokensHostRange[batchIdx] = numDraftTokens;
|
||||
auto numDraftTokensView = ITensor::slice(externalDraftTokensInputs->numDraftTokens, batchIdx, 1);
|
||||
runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream);
|
||||
|
||||
if (numDraftTokens > 0)
|
||||
{
|
||||
TensorPtr draftTokensReqBatchSlice
|
||||
= ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1);
|
||||
draftTokensReqBatchSlice->squeeze(0);
|
||||
TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens);
|
||||
TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens}));
|
||||
manager.copy(*draftTokensView, *draftTokensReqTokensSlice);
|
||||
TensorPtr draftTokenIdsHostSlice
|
||||
= ITensor::slice(externalDraftTokensInputs->draftTokenIdsHost, {batchIdx, 0}, numDraftTokens);
|
||||
// Copy to pinned host memory (don't care about stream of bufferManager)
|
||||
decoderBufferManager.copy(draftTokens->data(), *draftTokenIdsHostSlice);
|
||||
|
||||
TensorPtr draftTokenIdsSlice
|
||||
= ITensor::slice(externalDraftTokensInputs->draftTokenIds, {batchIdx, 0}, numDraftTokens);
|
||||
decoderBufferManager.copy(*draftTokenIdsHostSlice, *draftTokenIdsSlice);
|
||||
}
|
||||
|
||||
auto* numDraftTokensHostPtr
|
||||
= runtime::bufferCast<SizeType32>(*dJointInput.externalDraftTokensInputs->numDraftTokensHost);
|
||||
numDraftTokensHostPtr[batchIdx] = numDraftTokens;
|
||||
auto numDraftTokensView = ITensor::slice(dJointInput.externalDraftTokensInputs->numDraftTokens, batchIdx, 1);
|
||||
runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream);
|
||||
auto const& draftLogits = llmReq.getDraftLogits();
|
||||
auto const useDraftLogits = draftLogits.has_value();
|
||||
|
||||
auto useDraftLogitsHostRange = runtime::BufferRange<bool>(*externalDraftTokensInputs->useDraftLogitsHost);
|
||||
useDraftLogitsHostRange[batchIdx] = useDraftLogits;
|
||||
auto useDraftLogitsView = ITensor::slice(externalDraftTokensInputs->useDraftLogits, batchIdx, 1);
|
||||
runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream);
|
||||
|
||||
if (useDraftLogits)
|
||||
{
|
||||
TensorPtr draftLogitsHostSlice
|
||||
= ITensor::slice(externalDraftTokensInputs->draftLogitsHost, {batchIdx, 0}, numDraftTokens);
|
||||
retrieveDraftLogits(draftLogitsHostSlice, draftLogits.value(), modelConfig, worldConfig,
|
||||
speculativeDecodingFastLogits, isLeaderInOrchMode, decoderBufferManager);
|
||||
|
||||
TensorPtr draftLogitsSlice
|
||||
= ITensor::slice(externalDraftTokensInputs->draftLogits, {batchIdx, 0}, numDraftTokens);
|
||||
decoderBufferManager.copy(*draftLogitsHostSlice, *draftLogitsSlice);
|
||||
}
|
||||
|
||||
auto const& samplingConfig = llmReq.mSamplingConfig;
|
||||
bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value();
|
||||
float const constantThreshold
|
||||
= useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0];
|
||||
|
||||
dJointInput.externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold;
|
||||
dJointInput.externalDraftTokensInputs->constantThreshold = constantThreshold;
|
||||
externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold;
|
||||
externalDraftTokensInputs->constantThreshold = constantThreshold;
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
|
||||
DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens)
|
||||
//! @brief Setups decoder internal tensors for new Medusa request
|
||||
void newRequestMedusa(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest& llmReq,
|
||||
SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, MedusaBuffers const& medusaBuffers,
|
||||
CudaStream const& decoderStream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
llmReq.mSamplingConfig.topKMedusaHeads = {medusaBuffers.mTopKs};
|
||||
// FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
|
||||
// When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
|
||||
auto medusaPaths = ITensor::slice(medusaBuffers.medusaPathsDevice, 0, 1);
|
||||
auto medusaTreeIds = ITensor::slice(medusaBuffers.medusaTreeIdsDevice, 0, 1);
|
||||
|
||||
BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())};
|
||||
|
||||
auto& dJointInput = jointDecodingInput;
|
||||
auto& medusaInputs = jointDecodingInput.medusaInputs;
|
||||
|
||||
TensorPtr curTokensPerStepSlice
|
||||
= ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaCurTokensPerStep), batchIdx, 1);
|
||||
= ITensor::slice(constPointerCast(medusaInputs->medusaCurTokensPerStep), batchIdx, 1);
|
||||
// Context phase Medusa processes 1 token only, new value from targetTokensPerStep will be filled at the end
|
||||
// of first decoder
|
||||
runtime::kernels::invokeFill(*curTokensPerStepSlice, 1, decoderStream);
|
||||
TensorPtr targetTokensPerStepSlice
|
||||
= ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTargetTokensPerStep), batchIdx, 1);
|
||||
auto const generatedTokensPerEngineStep = request.generatedTokensPerEngineStep;
|
||||
TLLM_CHECK_WITH_INFO(generatedTokensPerEngineStep <= maxDecodingEngineTokens,
|
||||
"Tokens per step for (%d) is larger than maximum tokens per step (%d)", generatedTokensPerEngineStep,
|
||||
= ITensor::slice(constPointerCast(medusaInputs->medusaTargetTokensPerStep), batchIdx, 1);
|
||||
TLLM_CHECK_WITH_INFO(numDecodingEngineTokens <= maxDecodingEngineTokens,
|
||||
"Tokens per step for (%d) is larger than maximum tokens per step (%d)", numDecodingEngineTokens,
|
||||
maxDecodingEngineTokens);
|
||||
runtime::kernels::invokeFill(*targetTokensPerStepSlice, generatedTokensPerEngineStep, decoderStream);
|
||||
runtime::kernels::invokeFill(*targetTokensPerStepSlice, numDecodingEngineTokens, decoderStream);
|
||||
|
||||
TensorPtr pathsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaPaths), batchIdx, 1);
|
||||
manager.copy(*request.medusaPaths, *pathsSlice);
|
||||
TensorPtr pathsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaPaths), batchIdx, 1);
|
||||
manager.copy(*medusaPaths, *pathsSlice);
|
||||
|
||||
TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTreeIds), batchIdx, 1);
|
||||
manager.copy(*request.medusaTreeIds, *treeIdsSlice);
|
||||
TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaTreeIds), batchIdx, 1);
|
||||
manager.copy(*medusaTreeIds, *treeIdsSlice);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
|
||||
DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream)
|
||||
//! @brief Setups decoder internal tensors for new Lookahead request
|
||||
void newRequestLookahead(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, SizeType32 batchIdx,
|
||||
CudaStream const& runtimeStream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(jointDecodingOutput.lookaheadOutputs);
|
||||
TLLM_CHECK(jointDecodingInput.lookaheadInputs);
|
||||
|
||||
// The first generation step only generate 1 token.
|
||||
TensorPtr curTokensPerStepSlice
|
||||
@ -472,65 +493,72 @@ void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime:
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void CreateNewDecoderRequests::newRequestExplicitDraftTokens(SizeType32 batchIdx,
|
||||
runtime::decoder_batch::Request const& request, DecodingOutput& jointDecodingOutput,
|
||||
CudaStream const& runtimeStream)
|
||||
//! @brief Setups decoder internal tensors for new Explicit draft tokens request
|
||||
void newRequestExplicitDraftTokens(
|
||||
DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, CudaStream const& runtimeStream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(jointDecodingOutput.explicitDraftTokensBuffers);
|
||||
|
||||
auto const inputLen = llmReq.getPromptLen();
|
||||
|
||||
TensorPtr positionIdsBaseSlice
|
||||
= ITensor::slice(jointDecodingOutput.explicitDraftTokensBuffers->positionIdsBase, batchIdx, 1);
|
||||
runtime::kernels::invokeFill(*positionIdsBaseSlice, request.inputLen, runtimeStream);
|
||||
runtime::kernels::invokeFill(*positionIdsBaseSlice, inputLen, runtimeStream);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
//! @brief Sets up decoder internal tensors for a new Eagle request
void newRequestEagle(DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq,
    runtime::ModelConfig const& modelConfig, executor::DecodingConfig const& decodingConfig,
    CudaStream const& runtimeStream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    TLLM_CHECK(jointDecodingOutput.eagleBuffers);
    auto& eagleBuffers = *jointDecodingOutput.eagleBuffers;

    auto const inputLen = llmReq.getPromptLen();

    BufferManager manager{std::make_shared<CudaStream>(runtimeStream.get())};

    TensorPtr eagleNetCtxRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetCtxRequestTypesHost, batchIdx, 1);
    TensorPtr eagleNetCtxContextLengthsHostSlice
        = ITensor::slice(eagleBuffers.eagleNetCtxContextLengthsHost, batchIdx, 1);
    TensorPtr eagleNetCtxPastKeyValueLengthsHostSlice
        = ITensor::slice(eagleBuffers.eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1);

    runtime::bufferCast<SizeType32>(*eagleNetCtxRequestTypesHostSlice)[0] = 0;
    runtime::bufferCast<SizeType32>(*eagleNetCtxContextLengthsHostSlice)[0] = inputLen;
    runtime::bufferCast<SizeType32>(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = inputLen;

    TensorPtr eagleNetGenRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetGenRequestTypesHost, batchIdx, 1);
    TensorPtr eagleNetGenContextLengthsHostSlice
        = ITensor::slice(eagleBuffers.eagleNetGenContextLengthsHost, batchIdx, 1);
    TensorPtr eagleNetGenPastKeyValueLengthsHostSlice
        = ITensor::slice(eagleBuffers.eagleNetGenPastKeyValueLengthsHost, batchIdx, 1);

    runtime::bufferCast<SizeType32>(*eagleNetGenRequestTypesHostSlice)[0] = 1;
    runtime::bufferCast<SizeType32>(*eagleNetGenContextLengthsHostSlice)[0] = inputLen;
    runtime::bufferCast<SizeType32>(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = inputLen;

    auto const eagleModule = std::dynamic_pointer_cast<tensorrt_llm::runtime::EagleModule const>(
        modelConfig.getSpeculativeDecodingModulePtr());
    std::optional<executor::EagleChoices> eagleChoicesOpt;

    auto const& eagleConfig = llmReq.getEagleConfig() ? llmReq.getEagleConfig() : decodingConfig.getEagleConfig();

    if (eagleConfig)
    {
        eagleChoicesOpt = eagleConfig->getEagleChoices();
    }

    if (!eagleConfig || !eagleConfig->useDynamicTree())
    {
        TensorPtr draftPathsHostSlice = ITensor::slice(eagleBuffers.draftPathsHost, batchIdx, 1);
        TensorPtr draftPathsSlice = ITensor::slice(eagleBuffers.draftPaths, batchIdx, 1);

        // eagleConfig is nullptr or Eagle-1
        std::vector<SizeType32> topKs;
@ -546,6 +574,61 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
//! @brief Sets up decoder internal tensors for a new speculative decoding request
void newRequestSpeculativeDecoding(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput,
    SizeType32 batchIdx, LlmRequest& llmReq, SpeculativeDecodingMode const& speculativeDecodingMode,
    SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens,
    OptionalRef<MedusaBuffers const> medusaBuffers, runtime::ModelConfig const& modelConfig,
    WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, bool speculativeDecodingFastLogits,
    bool isLeaderInOrchMode, CudaStream const& runtimeStream, CudaStream const& decoderStream)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    if (speculativeDecodingMode.predictsDraftTokens())
    {
        BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())};

        TLLM_CHECK(jointDecodingOutput.speculativeDecodingOutputs);
        auto& speculativeDecodingOutputs = *jointDecodingOutput.speculativeDecodingOutputs;

        TensorPtr nextDraftTokens = ITensor::slice(speculativeDecodingOutputs.nextDraftTokens, batchIdx, 1);
        // FIXME: can we skip this?
        manager.setZero(*nextDraftTokens);
        if (speculativeDecodingMode.variableDraftLength())
        {
            TensorPtr nextDraftTokensLen = ITensor::slice(speculativeDecodingOutputs.nextDraftTokensLen, batchIdx, 1);
            manager.setZero(*nextDraftTokensLen);
        }
    }

    if (speculativeDecodingMode.isDraftTokensExternal())
    {
        newRequestDraftTokensExternal(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, modelConfig,
            worldConfig, speculativeDecodingFastLogits, isLeaderInOrchMode, decoderStream);
    }
    else if (speculativeDecodingMode.isMedusa())
    {
        TLLM_CHECK(medusaBuffers);
        newRequestMedusa(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, maxDecodingEngineTokens,
            medusaBuffers.value(), decoderStream);
    }
    else if (speculativeDecodingMode.isLookaheadDecoding())
    {
        newRequestLookahead(jointDecodingInput, jointDecodingOutput, batchIdx, runtimeStream);
    }
    else if (speculativeDecodingMode.isExplicitDraftTokens())
    {
        newRequestExplicitDraftTokens(jointDecodingOutput, batchIdx, llmReq, runtimeStream);
    }
    else if (speculativeDecodingMode.isEagle())
    {
        newRequestEagle(jointDecodingOutput, batchIdx, llmReq, modelConfig, decodingConfig, runtimeStream);
    }
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::tuple<std::vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>>
|
||||
CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
|
||||
executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
|
||||
@ -563,9 +646,6 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
|
||||
}
|
||||
inputIds->resize(decoderInputSize);
|
||||
|
||||
std::vector<decoder_batch::Request> decoderRequests;
|
||||
decoderRequests.reserve(finishedContextRequests.size());
|
||||
|
||||
std::vector<runtime::ITensor::SharedConstPtr> lookaheadPrompt;
|
||||
std::vector<executor::LookaheadDecodingConfig> lookaheadAlgoConfigs;
|
||||
if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
|
||||
@ -597,36 +677,18 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
|
||||
|
||||
auto const promptLen = llmReq->getPromptLen();
|
||||
|
||||
auto decoderRequest = decoder_batch::Request{promptLen};
|
||||
|
||||
SizeType32 numDecodingEngineTokens{1};
|
||||
if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal())
|
||||
{
|
||||
if (llmReq->hasDraftTokens())
|
||||
{
|
||||
auto const& draftTokens = llmReq->getDraftTokens();
|
||||
// Copy to pinned host memory (don't care about stream of bufferManager)
|
||||
decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL);
|
||||
auto const& draftLogits = llmReq->getDraftLogits();
|
||||
if (draftLogits.has_value())
|
||||
{
|
||||
decoderRequest.draftLogits
|
||||
= retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager);
|
||||
}
|
||||
decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
decoderRequest.generatedTokensPerEngineStep = 1;
|
||||
}
|
||||
numDecodingEngineTokens = llmReq->getNumDraftTokens() + 1;
|
||||
}
|
||||
else if (!modelConfig.getSpeculativeDecodingMode().isNone())
|
||||
{
|
||||
decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens();
|
||||
numDecodingEngineTokens = modelConfig.getMaxDecodingTokens();
|
||||
}
|
||||
|
||||
auto& dJointInput = decoderState.getJointDecodingInput();
|
||||
|
||||
auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep;
|
||||
initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens,
|
||||
maxSequenceLength, decoderBufferManager);
|
||||
decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens);
|
||||
@ -667,16 +729,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
|
||||
{
|
||||
TLLM_CHECK(beamWidth == 1);
|
||||
|
||||
if (modelConfig.getSpeculativeDecodingMode().isMedusa())
|
||||
{
|
||||
TLLM_CHECK(medusaBuffers);
|
||||
llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs};
|
||||
// FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest?
|
||||
// When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot.
|
||||
decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1);
|
||||
decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1);
|
||||
}
|
||||
else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
|
||||
if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
|
||||
{
|
||||
lookaheadPrompt.emplace_back(requestIds);
|
||||
|
||||
@ -684,67 +737,17 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon
|
||||
= llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value());
|
||||
lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig);
|
||||
}
|
||||
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
|
||||
{
|
||||
decoderRequest.eagleConfig
|
||||
= llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig();
|
||||
}
|
||||
|
||||
newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig,
|
||||
decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream,
|
||||
decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens());
|
||||
newRequestSpeculativeDecoding(decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(),
|
||||
batchSlot, *llmReq, decoderState.getSpeculativeDecodingMode(), numDecodingEngineTokens,
|
||||
decoderState.getMaxDecodingEngineTokens(), medusaBuffers, modelConfig, worldConfig, decodingConfig,
|
||||
mSpeculativeDecodingFastLogits, mIsLeaderInOrchMode, runtimeStream, decoderStream);
|
||||
}
|
||||
|
||||
decoderRequests.push_back(decoderRequest);
|
||||
|
||||
inputOffset += promptLen;
|
||||
}
|
||||
|
||||
return {std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
|
||||
}
|
||||
|
||||
std::shared_ptr<runtime::ITensor> CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig, std::shared_ptr<runtime::ITensor> const& tensor,
|
||||
BufferManager const& bufferManager) const
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
if (!mSpeculativeDecodingFastLogits)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return bufferManager.copyFrom(*tensor, MemoryType::kPINNEDPOOL);
|
||||
}
|
||||
|
||||
if (mIsLeaderInOrchMode)
|
||||
{
|
||||
te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo;
|
||||
std::memcpy(&fastLogitsInfo, tensor->data(), sizeof(fastLogitsInfo));
|
||||
auto logits = utils::targetModelReceiveLogits(fastLogitsInfo, modelConfig).value();
|
||||
|
||||
// Broadcast to other ranks if needed
|
||||
if (worldConfig.isTensorParallel())
|
||||
{
|
||||
auto const& commSession = COMM_SESSION;
|
||||
auto shape = logits->getShape();
|
||||
commSession.bcastValue(shape.d[0], 0);
|
||||
commSession.bcastValue(shape.d[1], 0);
|
||||
commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return logits;
|
||||
}
|
||||
|
||||
// Get logits from leader rank
|
||||
auto const& commSession = COMM_SESSION;
|
||||
int64_t dims[2];
|
||||
commSession.bcastValue(dims[0], 0);
|
||||
commSession.bcastValue(dims[1], 0);
|
||||
auto const logitsDtype = modelConfig.getLogitsDtype();
|
||||
auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(ITensor::makeShape({dims[0], dims[1]}), logitsDtype);
|
||||
commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return logits;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -558,18 +558,20 @@ void MLACacheFormatter::unformat(TransferSession& session)
|
||||
        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support MLA");
        return false;
    }

    if (selfConfig.getAttentionConfig().mKvFactor != destConfig.getAttentionConfig().mKvFactor)
    {
        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same kv factor");
        return false;
    }
    if (selfConfig.getParallelConfig().mEnableAttentionDP
        && (selfConfig.getParallelConfig().mTensorParallelism % selfConfig.getParallelConfig().mDPsize != 0))
    {
        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
        return false;
    }
    if (selfConfig.getParallelConfig().mContextParallelism != 1
        || destConfig.getParallelConfig().mContextParallelism != 1)
    {
        TLLM_LOG_WARNING(
            "MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
            selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
        return false;
    }
    if (destConfig.getParallelConfig().mEnableAttentionDP
        && (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
    {

@ -121,8 +121,8 @@ void draftModelSendLogitsThread(int device, std::atomic<bool>* draftModelThreadS
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
|
||||
executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig)
|
||||
void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost,
|
||||
executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
auto const& worldComm = tensorrt_llm::mpi::MpiComm::world();
|
||||
@ -151,10 +151,7 @@ std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
|
||||
int64_t dims[2];
|
||||
MPICHECK(MPI_Mrecv(&dims, count, MPI_INT64_T, &msg, &status));
|
||||
|
||||
auto const logitsDtype = modelConfig.getLogitsDtype();
|
||||
|
||||
auto tensor = tensorrt_llm::runtime::BufferManager::pinnedPool(
|
||||
runtime::ITensor::makeShape({dims[0], dims[1]}), logitsDtype);
|
||||
draftLogitsHost->reshape(runtime::ITensor::makeShape({dims[0], dims[1]}));
|
||||
|
||||
worldComm.mprobe(fastLogitsInfo.draftParticipantId, mpi::MpiTag::kSpecDecLogitsData, &msg, &status);
|
||||
|
||||
@ -163,11 +160,7 @@ std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
|
||||
uint64_t const expectedSize = static_cast<uint64_t>(dims[0]) * dims[1] * tc::getDTypeSize(logitsDtype);
|
||||
TLLM_CHECK((uint64_t) count == expectedSize);
|
||||
|
||||
MPICHECK(MPI_Mrecv(tensor->data(), count, MPI_UINT8_T, &msg, &status));
|
||||
|
||||
return tensor;
|
||||
#else
|
||||
return std::nullopt;
|
||||
MPICHECK(MPI_Mrecv(draftLogitsHost->data(), count, MPI_UINT8_T, &msg, &status));
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
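For reference, the receive path above expects two MPI messages per transfer: a two-element int64 shape, then the raw logits bytes. A minimal sketch of a matching sender in plain MPI (illustrative only; the function name, tags, and rank are placeholders, not the project's draftModelSendLogitsThread):

// Illustrative only: plain-MPI sketch of the shape-then-payload protocol used above.
// TAG_INFO/TAG_DATA and targetRank are placeholders, not TensorRT-LLM symbols.
#include <mpi.h>
void sendLogitsSketch(MPI_Comm comm, int targetRank, void const* data, int64_t rows, int64_t cols, size_t eltSize)
{
    int64_t dims[2] = {rows, cols};
    MPI_Send(dims, 2, MPI_INT64_T, targetRank, /*TAG_INFO=*/0, comm);
    // Payload is sent as raw bytes; the receiver checks the byte count against dims and the dtype size.
    MPI_Send(data, static_cast<int>(rows * cols * eltSize), MPI_UINT8_T, targetRank, /*TAG_DATA=*/1, comm);
}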
|
||||
|
||||
|
||||
@ -21,10 +21,8 @@
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
@ -52,7 +50,7 @@ void draftModelSendLogitsThread(int device, std::atomic<bool>* draftModelThreadS
|
||||
std::shared_ptr<kv_cache_manager::BaseKVCacheManager> const& crossKvCacheManager,
|
||||
std::shared_ptr<BasePeftCacheManager> const& peftCacheManager);
|
||||
|
||||
std::optional<runtime::ITensor::SharedPtr> targetModelReceiveLogits(
|
||||
executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig);
|
||||
void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost,
|
||||
executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype);
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::utils
|
||||
|
||||
@ -435,6 +435,14 @@ struct CutlassGemmConfig
|
||||
int sm_version = 80; // Use 80 as a catch all for <90
|
||||
bool is_tma_warp_specialized = false;
|
||||
|
||||
enum class EpilogueFusionType : int
|
||||
{
|
||||
NONE,
|
||||
FINALIZE
|
||||
};
|
||||
|
||||
EpilogueFusionType epilogue_fusion_type = EpilogueFusionType::NONE;
|
||||
|
||||
CutlassGemmConfig() = default;
|
||||
|
||||
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
|
||||
@ -505,7 +513,8 @@ struct CutlassGemmConfig
|
||||
<< "\n\tsm: " << sm_version << "\n\ttile shape ID: " << getTileConfigAsInt()
|
||||
<< "\n\tcluster shape ID: " << (int) cluster_shape
|
||||
<< "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule
|
||||
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false");
|
||||
<< "\n\tenable cuda kernel: " << (enableCudaKernel ? "true" : "false")
|
||||
<< "\n\tepilogue fusion type: " << (int) epilogue_fusion_type;
|
||||
}
|
||||
else if (tile_config_sm80 != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic)
|
||||
{
|
||||
@ -537,7 +546,8 @@ inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& conf
|
||||
<< ", mainloop_schedule_enum: " << int(config.mainloop_schedule)
|
||||
<< ", epilogue_schedule_enum: " << int(config.epilogue_schedule)
|
||||
<< ", cluster_shape_enum: " << int(config.cluster_shape)
|
||||
<< ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false");
|
||||
<< ", enable_cuda_kernel: " << (config.enableCudaKernel ? "true" : "false")
|
||||
<< ", epilogue_fusion_type: " << int(config.epilogue_fusion_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@ -531,14 +531,15 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is)
|
||||
auto tokensPerBlock = su::deserialize<decltype(CacheState::ModelConfig::mTokensPerBlock)>(is);
|
||||
auto tensorParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mTensorParallelism)>(is);
|
||||
auto pipelineParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mPipelineParallelism)>(is);
|
||||
auto contextParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mContextParallelism)>(is);
|
||||
auto enableAttentionDP = su::deserialize<decltype(CacheState::ParallelConfig::mEnableAttentionDP)>(is);
|
||||
auto DPrank = su::deserialize<decltype(CacheState::ParallelConfig::mDPrank)>(is);
|
||||
auto DPsize = su::deserialize<decltype(CacheState::ParallelConfig::mDPsize)>(is);
|
||||
auto dataType = su::deserialize<decltype(CacheState::mDataType)>(is);
|
||||
auto attentionType = su::deserialize<decltype(CacheState::AttentionConfig::mAttentionType)>(is);
|
||||
auto kvFactor = su::deserialize<decltype(CacheState::AttentionConfig::mKvFactor)>(is);
|
||||
return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, dataType,
|
||||
attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
|
||||
return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism,
|
||||
contextParallelism, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
|
||||
}
|
||||
|
||||
void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
|
||||
@ -548,6 +549,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& o
|
||||
su::serialize(state.mModelConfig.mTokensPerBlock, os);
|
||||
su::serialize(state.mParallelConfig.mTensorParallelism, os);
|
||||
su::serialize(state.mParallelConfig.mPipelineParallelism, os);
|
||||
su::serialize(state.mParallelConfig.mContextParallelism, os);
|
||||
su::serialize(state.mParallelConfig.mEnableAttentionDP, os);
|
||||
su::serialize(state.mParallelConfig.mDPrank, os);
|
||||
su::serialize(state.mParallelConfig.mDPsize, os);
|
||||
@ -564,6 +566,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state)
|
||||
totalSize += su::serializedSize(state.mModelConfig.mTokensPerBlock);
|
||||
totalSize += su::serializedSize(state.mParallelConfig.mTensorParallelism);
|
||||
totalSize += su::serializedSize(state.mParallelConfig.mPipelineParallelism);
|
||||
totalSize += su::serializedSize(state.mParallelConfig.mContextParallelism);
|
||||
totalSize += su::serializedSize(state.mParallelConfig.mEnableAttentionDP);
|
||||
totalSize += su::serializedSize(state.mParallelConfig.mDPrank);
|
||||
totalSize += su::serializedSize(state.mParallelConfig.mDPsize);
|
||||
|
||||
268
cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu
Normal file
268
cpp/tensorrt_llm/kernels/customMoeRoutingKernels.cu
Normal file
@ -0,0 +1,268 @@
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "moeTopKFuncs.cuh"
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/kernels/archCondition.h"
|
||||
#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
|
||||
#include <climits> // For INT_MAX
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda/std/limits> // For numeric_limits
|
||||
#include <math.h>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
static constexpr int BLOCK_SIZE = 1024;
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts)
|
||||
{
|
||||
T maxScore = T{-INFINITY};
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
maxScore = score >= maxScore ? score : maxScore;
|
||||
}
|
||||
maxScore = cg::reduce(warp, maxScore, cg::greater<T>());
|
||||
|
||||
float sumScore{0.f};
|
||||
float newScore;
|
||||
// Get the summation of scores for each token
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
newScore = static_cast<float>(score) - static_cast<float>(maxScore);
|
||||
newScore = static_cast<float>(exp(newScore));
|
||||
sumScore += newScore;
|
||||
}
|
||||
sumScore = cg::reduce(warp, sumScore, cg::plus<float>());
|
||||
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
score = static_cast<T>(newScore / sumScore);
|
||||
}
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
template <typename DataType, int VecSize>
|
||||
__device__ void calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, DataType (&scores)[VecSize])
|
||||
{
|
||||
DataType maxScore = DataType{-INFINITY};
|
||||
DataType sumScore = DataType{0.f};
|
||||
|
||||
// Get the max score for each token
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; ++i)
|
||||
{
|
||||
maxScore = scores[i] >= maxScore ? scores[i] : maxScore;
|
||||
}
|
||||
maxScore = cg::reduce(warp, maxScore, cg::greater<DataType>());
|
||||
|
||||
// Get the summation of scores for each token
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; ++i)
|
||||
{
|
||||
scores[i] = static_cast<DataType>(exp(scores[i] - maxScore));
|
||||
sumScore += scores[i];
|
||||
}
|
||||
sumScore = cg::reduce(warp, sumScore, cg::plus<DataType>());
|
||||
|
||||
// Normalize the scores
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; ++i)
|
||||
{
|
||||
scores[i] = static_cast<DataType>(scores[i] / sumScore);
|
||||
}
|
||||
}
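Both calcSoftmax overloads compute the max-subtracted softmax exp(x_i - max_j x_j) / sum_k exp(x_k - max_j x_j), with the max and the sum obtained via warp-wide cg::reduce. A single-threaded host reference of the same computation, useful for checking kernel outputs (hypothetical helper, not part of this file):

// Host-side reference for the warp softmax above (illustrative helper, not used by the kernel).
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> softmaxReference(std::vector<float> const& logits)
{
    float const maxVal = *std::max_element(logits.begin(), logits.end());
    std::vector<float> out(logits.size());
    float sum = 0.f;
    for (size_t i = 0; i < logits.size(); ++i)
    {
        out[i] = std::exp(logits[i] - maxVal); // subtract the max for numerical stability
        sum += out[i];
    }
    for (auto& v : out)
    {
        v /= sum;
    }
    return out;
}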
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts,
|
||||
bool DoSoftmaxBeforeTopK>
|
||||
__global__ void customMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices,
|
||||
int32_t const numTokens, int32_t const numExperts, int32_t const topK)
|
||||
{
|
||||
using BaseType = std::conditional_t<DoSoftmaxBeforeTopK, float, InputT>;
|
||||
uint32_t const blockRank = blockIdx.x;
|
||||
uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x;
|
||||
uint32_t const warpIdx = tIdx / WARP_SIZE;
|
||||
uint32_t const laneIdx = tIdx % WARP_SIZE;
|
||||
uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK;
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
BaseType minScore = BaseType{-INFINITY};
|
||||
for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
|
||||
{
|
||||
auto scoreOffset = tokenId * numExperts;
|
||||
auto outputOffset = tokenId * topK;
|
||||
|
||||
BaseType inputScore[MaxNumExperts / WARP_SIZE];
|
||||
IdxT inputIndex[MaxNumExperts / WARP_SIZE];
|
||||
|
||||
BaseType warpTopKScore[MaxNumTopExperts];
|
||||
IdxT warpTopKExpertIdx[MaxNumTopExperts];
|
||||
|
||||
// Load scores and indices for this warp
|
||||
for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i)
|
||||
{
|
||||
auto expertIdx = i * WARP_SIZE + laneIdx;
|
||||
inputScore[i]
|
||||
= expertIdx < numExperts ? static_cast<BaseType>(routerLogits[scoreOffset + expertIdx]) : minScore;
|
||||
inputIndex[i] = expertIdx;
|
||||
}
|
||||
|
||||
if constexpr (DoSoftmaxBeforeTopK)
|
||||
{
|
||||
calcSoftmax(warp, inputScore);
|
||||
}
|
||||
// Reduce topK scores and indices for this warp
|
||||
reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);
|
||||
|
||||
// Normalize the scores
|
||||
if constexpr (DoSoftmaxBeforeTopK)
|
||||
{
|
||||
if (laneIdx < topK)
|
||||
{
|
||||
topkValues[outputOffset + laneIdx] = static_cast<OutputT>(warpTopKScore[laneIdx]);
|
||||
topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto softmaxScore = calcSoftmax(warp,
|
||||
laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx,
|
||||
topK);
|
||||
if (laneIdx < topK)
|
||||
{
|
||||
topkValues[outputOffset + laneIdx] = static_cast<OutputT>(softmaxScore);
|
||||
topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
|
||||
}
|
||||
}
|
||||
} // end for tokenId
|
||||
}
|
||||
|
||||
int nextPowerOfTwo(int num)
|
||||
{
|
||||
if (num <= 0)
|
||||
{
|
||||
return 1; // Handle invalid input
|
||||
}
|
||||
int power = 1;
|
||||
while (power < num)
|
||||
{
|
||||
// Check for overflow before shifting
|
||||
if (power > INT_MAX / 2)
|
||||
{
|
||||
return power;
|
||||
}
|
||||
power <<= 1;
|
||||
}
|
||||
return power;
|
||||
}
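A few sample values of the helper above, which the dispatch below uses after clamping the expert count to at least 32 (illustrative):

// nextPowerOfTwo(1)  == 1
// nextPowerOfTwo(5)  == 8
// nextPowerOfTwo(32) == 32
// nextPowerOfTwo(96) == 128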
|
||||
|
||||
#define CASE(MAX_NUM_EXPERTS) \
|
||||
case MAX_NUM_EXPERTS: \
|
||||
switch (maxNumTopExperts) \
|
||||
{ \
|
||||
case 1: \
|
||||
kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1, DoSoftmaxBeforeTopK>; \
|
||||
break; \
|
||||
case 2: \
|
||||
kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2, DoSoftmaxBeforeTopK>; \
|
||||
break; \
|
||||
case 4: \
|
||||
kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4, DoSoftmaxBeforeTopK>; \
|
||||
break; \
|
||||
case 8: \
|
||||
kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8, DoSoftmaxBeforeTopK>; \
|
||||
break; \
|
||||
default: kernelInstance = nullptr; break; \
|
||||
} \
|
||||
break;
|
||||
|
||||
template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
|
||||
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
|
||||
int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
|
||||
{
|
||||
|
||||
const uint32_t maxNumBlocks = 1024;
|
||||
const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks);
|
||||
|
||||
uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts);
|
||||
uint32_t maxNumTopExperts = nextPowerOfTwo(topK);
|
||||
|
||||
auto* kernelInstance = &customMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8, DoSoftmaxBeforeTopK>;
|
||||
|
||||
switch (maxNumExperts)
|
||||
{
|
||||
CASE(32)
|
||||
CASE(64)
|
||||
CASE(96)
|
||||
CASE(128)
|
||||
default: kernelInstance = nullptr; break;
|
||||
}
|
||||
|
||||
    TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Cannot find a corresponding kernel instance.");
|
||||
|
||||
dim3 renormMoeRoutingGridDim(numBlocks);
|
||||
dim3 renormMoeRoutingBlockDim(BLOCK_SIZE);
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = renormMoeRoutingGridDim;
|
||||
config.blockDim = renormMoeRoutingBlockDim;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens),
|
||||
static_cast<int32_t>(numExperts), static_cast<int32_t>(topK));
|
||||
sync_check_cuda_error(stream);
|
||||
}
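A minimal caller sketch for the launcher above, assuming device buffers are already allocated; the sizes and the half input type are illustrative and correspond to one of the explicit instantiations below:

// Illustrative call site (allocation and error checking omitted).
// routerLogitsDev: [numTokens, numExperts] half, topkValuesDev: [numTokens, topK] float,
// topkIndicesDev: [numTokens, topK] int32.
#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
#include <cuda_fp16.h>

void runRoutingExample(half* routerLogitsDev, float* topkValuesDev, int32_t* topkIndicesDev, cudaStream_t stream)
{
    int64_t const numTokens = 256;
    int64_t const numExperts = 64;
    int64_t const topK = 8;
    // true -> softmax over all experts, then top-k; false -> top-k, then softmax over the selected experts.
    tensorrt_llm::kernels::invokeRenormMoeRouting<half, float, int32_t, /*DoSoftmaxBeforeTopK=*/true>(
        routerLogitsDev, topkValuesDev, topkIndicesDev, numTokens, numExperts, topK, stream);
}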
|
||||
|
||||
#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT, DoSoftmaxBeforeTopK) \
|
||||
template void invokeRenormMoeRouting<InputT, OutputT, IdxT, DoSoftmaxBeforeTopK>(InputT * routerLogits, \
|
||||
OutputT * topkValues, IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, \
|
||||
int64_t const topK, cudaStream_t const stream);
|
||||
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, false);
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, false);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, false);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t, true);
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t, true);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t, true);
|
||||
#endif
|
||||
|
||||
} // namespace tensorrt_llm::kernels
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -23,7 +23,7 @@
|
||||
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
template <typename InputT, typename OutputT, typename IdxT>
|
||||
template <typename InputT, typename OutputT, typename IdxT, bool DoSoftmaxBeforeTopK>
|
||||
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
|
||||
int64_t const numExperts, int64_t const topK, cudaStream_t const stream);
|
||||
} // namespace tensorrt_llm::kernels
|
||||
@ -288,15 +288,20 @@ public:
|
||||
void moeGemm(GroupedGemmInput<T, WeightType, ScaleBiasType, OutputType> inputs,
|
||||
TmaWarpSpecializedGroupedGemmInput hopper_inputs);
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs() const;
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm);
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getTmaWarpSpecializedConfigs(int sm);
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getBlackwellConfigs(int sm);
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getHopperConfigs(int sm);
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(bool supports_finalize_fusion) const;
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getConfigs(int sm, bool supports_finalize_fusion);
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getTmaWarpSpecializedConfigs(
|
||||
int sm, bool supports_finalize_fusion);
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getAmpereConfigs(int sm);
|
||||
|
||||
[[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const;
|
||||
[[nodiscard]] bool supportsTmaWarpSpecialized() const;
|
||||
|
||||
[[nodiscard]] bool supportsTmaWarpSpecialized() const
|
||||
{
|
||||
return supportsTmaWarpSpecialized(sm_);
|
||||
}
|
||||
|
||||
[[nodiscard]] static bool supportsTmaWarpSpecialized(int sm);
|
||||
[[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config,
|
||||
ActivationType activation_type, int gemm_n, int gemm_k) const;
|
||||
[[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const;
|
||||
|
||||
@ -228,6 +228,13 @@ struct MOEParallelismConfig
|
||||
}
|
||||
};
|
||||
|
||||
enum class MoeGemmId : int
|
||||
{
|
||||
Undefined = 0,
|
||||
GEMM_1,
|
||||
GEMM_2
|
||||
};
|
||||
|
||||
struct QuantParams
|
||||
{
|
||||
// Int weight only quantization params
|
||||
@ -446,7 +453,7 @@ public:
|
||||
virtual void setTactic(std::optional<cutlass_extensions::CutlassGemmConfig> gemm1_config,
|
||||
std::optional<cutlass_extensions::CutlassGemmConfig> gemm2_config)
|
||||
= 0;
|
||||
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() = 0;
|
||||
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(MoeGemmId gemm_id) = 0;
|
||||
|
||||
virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
|
||||
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
|
||||
@ -593,15 +600,15 @@ public:
|
||||
gemm2_config_ = std::move(gemm2_config);
|
||||
}
|
||||
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() override
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(MoeGemmId gemm_id) override
|
||||
{
|
||||
return moe_gemm_runner_.getConfigs();
|
||||
return moe_gemm_runner_.getConfigs(gemm_id == MoeGemmId::GEMM_2 && mayHaveFinalizeFused());
|
||||
}
|
||||
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(int sm)
|
||||
static std::vector<cutlass_extensions::CutlassGemmConfig> getTactics(int sm, MoeGemmId gemm_id)
|
||||
{
|
||||
using RunnerType = decltype(moe_gemm_runner_);
|
||||
return RunnerType::getConfigs(sm);
|
||||
return RunnerType::getConfigs(sm, gemm_id == MoeGemmId::GEMM_2 && Self::mayHaveFinalizeFused(sm));
|
||||
}
|
||||
|
||||
void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
|
||||
@ -798,6 +805,12 @@ private:
|
||||
&& !use_w4_groupwise;
|
||||
}
|
||||
|
||||
static bool mayHaveFinalizeFused(int sm)
|
||||
{
|
||||
using RunnerType = decltype(moe_gemm_runner_);
|
||||
return RunnerType::supportsTmaWarpSpecialized(sm) && sm >= 90 && !use_w4_groupwise;
|
||||
}
|
||||
|
||||
// TODO: This should eventually take the quant params to give more flexibility
|
||||
static auto getScalingType()
|
||||
{
|
||||
@ -895,12 +908,7 @@ struct GemmProfilerBackend
|
||||
{
|
||||
public:
|
||||
using Config = cutlass_extensions::CutlassGemmConfig;
|
||||
enum class GemmToProfile
|
||||
{
|
||||
Undefined = 0,
|
||||
GEMM_1,
|
||||
GEMM_2
|
||||
};
|
||||
using GemmToProfile = MoeGemmId;
|
||||
|
||||
void init(CutlassMoeFCRunnerInterface& runner, GemmToProfile gemm_to_profile, nvinfer1::DataType dtype,
|
||||
nvinfer1::DataType wtype, nvinfer1::DataType otype, int num_experts, int k, int64_t hidden_size,
|
||||
@ -951,7 +959,6 @@ public:
|
||||
CutlassMoeFCRunnerInterface* mInterface;
|
||||
|
||||
GemmToProfile mGemmToProfile = GemmToProfile::Undefined;
|
||||
std::vector<Config> mAllTacticsSaved;
|
||||
int mSM{};
|
||||
int64_t mNumExperts{};
|
||||
int64_t mNumExpertsPerNode{};
|
||||
@ -972,7 +979,7 @@ public:
|
||||
// This will be a unique value for every iteration of warmup and actual bench
|
||||
constexpr static int64_t NUM_ROUTING_SAMPLES = 16;
|
||||
|
||||
std::array<TmaWarpSpecializedGroupedGemmInput, NUM_ROUTING_SAMPLES> mTmaInputCache;
|
||||
std::array<std::array<TmaWarpSpecializedGroupedGemmInput, 2>, NUM_ROUTING_SAMPLES> mTmaInputCache;
|
||||
QuantParams mQuantParams;
|
||||
|
||||
bool mBias{};
|
||||
@ -985,7 +992,8 @@ public:
|
||||
private:
|
||||
void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream);
|
||||
void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream);
|
||||
void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream);
|
||||
void prepareTmaWsInputs(int num_tokens, char* workspace, void const* expert_weights,
|
||||
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream);
|
||||
};
|
||||
|
||||
// Populates a buffer with random values for use with MOE benchmarking
|
||||
|
||||
@ -57,7 +57,6 @@ namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
template <typename T, typename arch, typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
void genericInt8GemmKernelLauncher(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
|
||||
float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace,
|
||||
|
||||
@ -475,17 +475,18 @@ void dispatchMoeGemmToCutlass(GroupedGemmInput<T, WeightType, GemmOutputType, Ge
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs() const
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
|
||||
bool supports_finalize_fusion) const
|
||||
{
|
||||
return getConfigs(sm_);
|
||||
return getConfigs(sm_, supports_finalize_fusion);
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getConfigs(
|
||||
int sm)
|
||||
int sm, bool supports_finalize_fusion)
|
||||
{
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs = getTmaWarpSpecializedConfigs(sm);
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> candidate_configs
|
||||
= getTmaWarpSpecializedConfigs(sm, supports_finalize_fusion);
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig> ampere_configs = getAmpereConfigs(sm);
|
||||
std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs));
|
||||
return candidate_configs;
|
||||
@ -521,7 +522,8 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getAmpereConfigs(int sm
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
std::vector<cutlass_extensions::CutlassGemmConfig>
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedConfigs(int sm)
|
||||
MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedConfigs(
|
||||
int sm, bool supports_finalize_fusion)
|
||||
{
|
||||
using tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
static constexpr auto weight_only_flag
|
||||
@ -568,6 +570,17 @@ MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::getTmaWarpSpecializedCo
|
||||
= tensorrt_llm::kernels::cutlass_kernels::get_candidate_configs(100, max_split_k, config_type_param);
|
||||
std::copy(sm100_configs.begin(), sm100_configs.end(), std::back_inserter(tma_ws_configs));
|
||||
}
|
||||
if (supports_finalize_fusion)
|
||||
{
|
||||
// Duplicate the configs and set the epilogue fusion type to FINALIZE
|
||||
auto finalize_configs = tma_ws_configs;
|
||||
std::transform(finalize_configs.begin(), finalize_configs.end(), std::back_inserter(tma_ws_configs),
|
||||
[](auto& config)
|
||||
{
|
||||
config.epilogue_fusion_type = cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
|
||||
return config;
|
||||
});
|
||||
}
|
||||
return tma_ws_configs;
|
||||
}
|
||||
|
||||
@ -580,13 +593,11 @@ bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::isTmaWarpSpecializ
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsTmaWarpSpecialized() const
|
||||
bool MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::supportsTmaWarpSpecialized(int sm)
|
||||
{
|
||||
return (sm_ == 90 && tensorrt_llm::kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|
||||
|| (sm_ >= 100 && sm_ < 120
|
||||
&& tensorrt_llm::kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType>())
|
||||
|| ((sm_ == 120 || sm_ == 121)
|
||||
&& tensorrt_llm::kernels::cutlass_kernels::isValidSM120MOESpecialisation<T, WeightType>());
|
||||
return (sm == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation<T, WeightType>())
|
||||
|| (sm >= 100 && sm < 120 && kernels::cutlass_kernels::isValidBlackwellMOESpecialisation<T, WeightType>())
|
||||
|| ((sm == 120 || sm == 121) && kernels::cutlass_kernels::isValidSM120MOESpecialisation<T, WeightType>());
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename OutputType, typename ScaleBiasType>
|
||||
@ -833,7 +844,9 @@ size_t MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>::calcMaxWorkspace
|
||||
if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation<T, WeightType>() && !use_w4afp8
|
||||
&& !use_wfp4a16)
|
||||
{
|
||||
auto configs = getTmaWarpSpecializedConfigs(sm_);
|
||||
// Finalize fusion may not actually be supported by the kernel,
|
||||
// if they are not we will catch the error and skip them
|
||||
auto configs = getTmaWarpSpecializedConfigs(sm_, true);
|
||||
auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE;
|
||||
if constexpr (use_wfp4afp8)
|
||||
{
|
||||
|
||||
@ -2847,9 +2847,10 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
|
||||
expert_first_token_offset_ = getWsPtr(int64_t{}, "expert_first_token_offset");
|
||||
|
||||
// We check if the provided config uses fused finalize and disable it if it does not
|
||||
bool const gemm2_using_tma_ws = moe_gemm_runner_.isTmaWarpSpecialized(*gemm2_config_);
|
||||
bool gemm2_using_finalize_fusion
|
||||
= gemm2_config_->epilogue_fusion_type == cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
|
||||
permuted_token_final_scales_
|
||||
= (gemm2_using_tma_ws && mayHaveFinalizeFused()) ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr;
|
||||
= gemm2_using_finalize_fusion ? getWsPtr(float{}, "permuted_token_final_scales") : nullptr;
|
||||
|
||||
bool const is_gated_activation = isGatedActivation(activation_type);
|
||||
bool const gemm1_using_fused_moe
|
||||
@ -4006,8 +4007,12 @@ CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::
|
||||
|
||||
bool apply_bias = parallelism_config.tp_rank == 0;
|
||||
auto* fc2_bias = apply_bias ? fc2_expert_biases : nullptr;
|
||||
bool gemm2_using_finalize_fusion = gemm2_config_->epilogue_fusion_type
|
||||
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
|
||||
bool using_fused_finalize
|
||||
= use_fused_finalize_ && gemm2_config_->sm_version >= 90 && !use_w4_groupwise && !use_lora;
|
||||
= use_fused_finalize_ && gemm2_using_finalize_fusion && !use_w4_groupwise && !use_lora;
|
||||
TLLM_CHECK_WITH_INFO(using_fused_finalize == gemm2_using_finalize_fusion,
|
||||
"GEMM2 tactic requests finalize fusion, but the runner is not configured to use it");
|
||||
if (using_fused_finalize)
|
||||
{
|
||||
assert(min_latency_mode == false);
|
||||
@ -4550,14 +4555,26 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr
|
||||
}
|
||||
}
|
||||
|
||||
void GemmProfilerBackend::prepareTmaWsInputs(
|
||||
int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
|
||||
void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr_char, void const* expert_weights,
|
||||
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion fusion, cudaStream_t stream)
|
||||
{
|
||||
if (mSM < 90)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4);
|
||||
bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16)
|
||||
&& mWType == nvinfer1::DataType::kUINT8);
|
||||
bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
|
||||
bool const use_finalize_fusion = fusion == TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
|
||||
bool const finalize_fusion_not_supported = !mInterface->use_fused_finalize_ || mMinLatencyMode || use_w4_groupwise
|
||||
|| mGemmToProfile != GemmToProfile::GEMM_2;
|
||||
if (use_finalize_fusion && finalize_fusion_not_supported)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90);
|
||||
|
||||
#define GET_WS_PTR(type, name) \
|
||||
@ -4596,28 +4613,24 @@ void GemmProfilerBackend::prepareTmaWsInputs(
|
||||
size_t num_expanded_tokens = num_tokens * mK;
|
||||
for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++)
|
||||
{
|
||||
mTmaInputCache[i].configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace,
|
||||
// Note: Even though we have separate TMA WS inputs for finalize fusion on/off we reuse the same pointers to
|
||||
// save space.
|
||||
auto& cache_element = mTmaInputCache[i][use_finalize_fusion];
|
||||
cache_element.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace,
|
||||
workspaces.at("gemm_workspace").first, mScalingType);
|
||||
tma_ws_input_workspace += tma_ws_size;
|
||||
|
||||
int64_t* expert_first_token_offset = expert_first_token_offset_base + i * (mNumExpertsPerNode + 1);
|
||||
int* permuted_row_to_unpermuted_row = permuted_row_to_unpermuted_row_base + i * num_expanded_tokens;
|
||||
|
||||
auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? mTmaInputCache[i] : dummy_tma_ws_input;
|
||||
auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? mTmaInputCache[i] : dummy_tma_ws_input;
|
||||
auto& gemm1_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_1 ? cache_element : dummy_tma_ws_input;
|
||||
auto& gemm2_tma_ws_input = mGemmToProfile == GemmToProfile::GEMM_2 ? cache_element : dummy_tma_ws_input;
|
||||
if (mSM >= 90)
|
||||
{
|
||||
/* GEMM1 */
|
||||
gemm1_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
|
||||
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE;
|
||||
|
||||
bool use_w4afp8 = (mDType == nvinfer1::DataType::kFP8 && mWType == nvinfer1::DataType::kINT4);
|
||||
bool use_wfp4a16 = ((mDType == nvinfer1::DataType::kHALF || mDType == nvinfer1::DataType::kBF16)
|
||||
&& mWType == nvinfer1::DataType::kUINT8);
|
||||
bool use_w4_groupwise = use_w4afp8 || use_wfp4a16;
|
||||
bool using_fused_finalize
|
||||
= mInterface->use_fused_finalize_ && mSM >= 90 && !mMinLatencyMode && !use_w4_groupwise;
|
||||
if (using_fused_finalize)
|
||||
if (use_finalize_fusion)
|
||||
{
|
||||
assert(!mMinLatencyMode);
|
||||
gemm2_tma_ws_input.fusion = TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE;
|
||||
@ -4652,7 +4665,6 @@ void GemmProfilerBackend::prepareTmaWsInputs(
|
||||
void GemmProfilerBackend::prepare(
|
||||
int num_tokens, char* workspace_ptr_char, void const* expert_weights, cudaStream_t stream)
|
||||
{
|
||||
mAllTacticsSaved = mInterface->getTactics();
|
||||
mSampleIndex = 0;
|
||||
|
||||
auto workspace_size = getWorkspaceSize(num_tokens);
|
||||
@ -4660,7 +4672,10 @@ void GemmProfilerBackend::prepare(
|
||||
|
||||
prepareRouting(num_tokens, workspace_ptr_char, stream);
|
||||
prepareQuantParams(num_tokens, workspace_ptr_char, stream);
|
||||
prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, stream);
|
||||
prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights,
|
||||
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::NONE, stream);
|
||||
prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights,
|
||||
TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE, stream);
|
||||
}
|
||||
|
||||
size_t GemmProfilerBackend::getWorkspaceSize(int maxM)
|
||||
@ -4724,7 +4739,9 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac
|
||||
TmaWarpSpecializedGroupedGemmInput tma_ws_input_template;
|
||||
if (tactic.is_tma_warp_specialized)
|
||||
{
|
||||
tma_ws_input_template = mTmaInputCache[mSampleIndex];
|
||||
tma_ws_input_template = mTmaInputCache[mSampleIndex][tactic.epilogue_fusion_type
|
||||
== cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE];
|
||||
TLLM_CHECK_WITH_INFO(tma_ws_input_template.isValid(), "TMA WS input template is not initialized");
|
||||
}
|
||||
|
||||
mInterface->is_profiler = true;
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:86586b9f6845e91e8ba0accad53a5a3418c50d8fd30ad49fa8837470c72b5dcf
|
||||
size 67051604
|
||||
oid sha256:d6a3f6adef11003f794a6cec1235d0c622ead71b4e801a89866e91dfd91bb30c
|
||||
size 67053244
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
568cb6ca2413c93b0f5839dd05577c0c57bc4b5f2359366c79d0ace665de4bd6 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
commit 9c0a42825905952beaf9b35d5a35d58de1a123fa
|
||||
317a25037093a6f3d156ffa58a68bce53071ef68dacdcb04cc0aaeea80b64e76 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6489751f16a4dadf42664738ded03fbbd60195619f2d5f80af8190554318257d
|
||||
size 66872936
|
||||
oid sha256:489fb557b78062efedd1514f2995fafb9216bb0e0068a550e86763efb9d5eee9
|
||||
size 66874608
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
813c237a565664b2acf2313f0e436f66f24deeb16a84d273dc007af55795e55f libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
commit 9c0a42825905952beaf9b35d5a35d58de1a123fa
|
||||
5a31acd0fb1415196bff71fa4a8d1dded147e15ea10821cc46c85684c66986ee libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
commit 444ef1b3b06cdc7ee66b4e612ce26ad25967440b
|
||||
|
||||
205
cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
Normal file
205
cpp/tensorrt_llm/kernels/moeTopKFuncs.cuh
Normal file
@ -0,0 +1,205 @@
|
||||
|
||||
/*
|
||||
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
#ifndef TRTLLM_MOETOPKFUNCS_CUH_H
|
||||
#define TRTLLM_MOETOPKFUNCS_CUH_H
|
||||
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include "tensorrt_llm/kernels/archCondition.h"
|
||||
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
namespace reduce_topk
|
||||
{
|
||||
namespace cg = cooperative_groups;
|
||||
static constexpr int kWARP_SIZE = 32;
|
||||
static constexpr bool kTLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
|
||||
|
||||
template <typename T_>
|
||||
struct TopKRedType
|
||||
{
|
||||
using T = T_;
|
||||
static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>
|
||||
|| std::is_same_v<T, int>,
|
||||
"Top K reduction only implemented for int, float, float16 and bfloat16");
|
||||
|
||||
using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;
|
||||
|
||||
static constexpr int kMoveBits = (sizeof(T) == 4) ? 32 : 16;
|
||||
static constexpr int kMaxIdx = 65535;
|
||||
TypeCmp compValIdx;
|
||||
|
||||
static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
|
||||
{
|
||||
auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
|
||||
TypeCmp compactTmp = reinterpret_cast<TypeCmp&>(valueBits);
|
||||
compactTmp = (compactTmp << kMoveBits) | (0xFFFF & (kMaxIdx - idx));
|
||||
// Use 65535 minus idx to give higher priority to elements with smaller indices.
|
||||
return compactTmp;
|
||||
}
|
||||
|
||||
static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
|
||||
{
|
||||
// Since (65535 - idx) is always positive and smaller than 65536, we can directly use it as the lower 16 bits
|
||||
index = kMaxIdx - static_cast<int32_t>((cmp & 0xFFFF));
|
||||
|
||||
auto compactTmp = cmp >> kMoveBits;
|
||||
auto valueBits
|
||||
= cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
|
||||
value = reinterpret_cast<T&>(valueBits);
|
||||
}
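The packed key keeps the order-preserving (twiddled) value bits in the high half and (kMaxIdx - idx) in the low 16 bits, so a single unsigned max over packed keys selects the largest value and, on ties, the smallest index. A concrete example (illustrative, float case):

// Example (float, TypeCmp == uint64_t): for val = 0.5f at idx = 3 and val = 0.5f at idx = 7,
// the high bits of both keys are identical, but the low 16 bits are 65532 vs 65528,
// so the warp-wide max picks the key for idx = 3 -- ties resolve to the smaller expert index.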
|
||||
|
||||
__host__ __device__ TopKRedType() = default;
|
||||
|
||||
__host__ __device__ TopKRedType(T val, int32_t idx)
|
||||
: compValIdx(makeCmpVal(val, idx))
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ operator TypeCmp() const noexcept
|
||||
{
|
||||
return compValIdx;
|
||||
}
|
||||
|
||||
__device__ inline TypeCmp reduce(cg::thread_block_tile<kWARP_SIZE> const& warp)
|
||||
{
|
||||
if constexpr (!kTLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8)
|
||||
{
|
||||
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
TypeCmp result;
|
||||
asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
|
||||
return result;
|
||||
}
|
||||
}
|
||||
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int K_, bool Enable_>
struct TopKIdx
{
    // by default, empty
};

template <int K_>
struct TopKIdx<K_, true>
{
    static constexpr int K = K_;
    int32_t val[K];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

#define TOPK_SWAP(I, J) \
    { \
        auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \
        auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \
        topK[I].compValIdx = pairMax; \
        topK[J].compValIdx = pairMin; \
    }

template <int N, typename RedType>
struct Sort;

template <typename RedType>
struct Sort<1, RedType>
{
    static __device__ void run(RedType* topK) {}
};

template <typename RedType>
struct Sort<2, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
    }
};

template <typename RedType>
struct Sort<3, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 1);
        TOPK_SWAP(1, 2);
        TOPK_SWAP(0, 1);
    }
};

template <typename RedType>
struct Sort<4, RedType>
{
    static __device__ void run(RedType* topK)
    {
        TOPK_SWAP(0, 2);
        TOPK_SWAP(1, 3);
        TOPK_SWAP(0, 1);
        TOPK_SWAP(2, 3);
        TOPK_SWAP(1, 2);
    }
};

template <int K, typename Type, int N, bool IsSorted = false>
__device__ void reduceTopK(cg::thread_block_tile<kWARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
    Type (&value)[N], int32_t (&idx)[N], Type minValue)
{
    static_assert(K > 0, "Top K must have K > 0");
    static_assert(K < kWARP_SIZE, "Top K must have K < kWARP_SIZE");
    static_assert(N > 0, "Top K must have N > 0");
    static_assert(N < 5, "Only up to 4 candidates per lane (at most 128 per warp) are supported");
    using RedType = TopKRedType<Type>;
    RedType topK[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
    {
        topK[nn] = RedType{value[nn], idx[nn]};
    }

    if constexpr (!IsSorted)
    {
        Sort<N, RedType>::run(topK);
    }
    typename RedType::TypeCmp packedMax{};
#pragma unroll
    for (int kk = 0; kk < K; ++kk)
    {
        bool update = kk > 0 && packedMax == topK[0].compValIdx;
#pragma unroll
        for (int nn = 0; nn < N; ++nn)
        {
            topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
        }
        // get the next largest value
        packedMax = topK[0].reduce(warp);
        RedType::unpack(out[kk], outIdx[kk], packedMax);
    }
};
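A minimal device-side sketch of the intended call pattern (illustrative; `scores` and the enclosing kernel are hypothetical). Each lane contributes N candidates, so a full warp covers up to N * 32 elements, and every lane receives the same K results in descending order:

// Inside a __global__ function, within namespace tensorrt_llm::kernels.
namespace cg = cooperative_groups;
auto warp = cg::tiled_partition<kWARP_SIZE>(cg::this_thread_block());

constexpr int N = 4; // candidates per lane -> up to 128 per warp
constexpr int K = 8;
float value[N];
int32_t idx[N];
float const minScore = -INFINITY;
for (int i = 0; i < N; ++i)
{
    int32_t elem = i * kWARP_SIZE + warp.thread_rank();
    value[i] = scores[elem]; // hypothetical per-row pointer with at least N * kWARP_SIZE entries
    idx[i] = elem;
}

float out[K];
int32_t outIdx[K];
reduce_topk::reduceTopK(warp, out, outIdx, value, idx, minScore);
// out/outIdx now hold the warp-wide top-K on every lane.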

#undef TOPK_SWAP

} // namespace reduce_topk
} // namespace tensorrt_llm::kernels
#endif // TRTLLM_MOETOPKFUNCS_CUH_H
@ -1,376 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include "tensorrt_llm/kernels/archCondition.h"
|
||||
#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h"
|
||||
#include <climits> // For INT_MAX
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda/std/limits> // For numeric_limits
|
||||
#include <math.h>
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
static constexpr int BLOCK_SIZE = 1024;
|
||||
static constexpr int WARP_SIZE = 32;
|
||||
static constexpr int WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
|
||||
|
||||
namespace reduce_topk
|
||||
{
|
||||
|
||||
static constexpr bool TLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_major_v<10>;
|
||||
|
||||
template <typename T_>
|
||||
struct TopKRedType
|
||||
{
|
||||
using T = T_;
|
||||
static_assert(std::is_same_v<T, float> || std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>,
|
||||
"Top K reduction only implemented for float, float16 and bfloat16");
|
||||
|
||||
using TypeCmp = std::conditional_t<sizeof(T) == 4, uint64_t, uint32_t>;
|
||||
using IdxT = std::conditional_t<sizeof(T) == 4, int32_t, int16_t>;
|
||||
static constexpr int moveBits = (sizeof(T) == 4) ? 32 : 16;
|
||||
static constexpr int maxIdx = 65535;
|
||||
TypeCmp compValIdx;
|
||||
|
||||
static __host__ __device__ inline TypeCmp makeCmpVal(T val, int32_t idx = 0)
|
||||
{
|
||||
auto valueBits = cub::Traits<T>::TwiddleIn(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(val));
|
||||
TypeCmp compactTmp = reinterpret_cast<TypeCmp&>(valueBits);
|
||||
compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx));
|
||||
// Use 65535 minus idx to give higher priority to elements with smaller indices.
|
||||
return compactTmp;
|
||||
}
|
||||
|
||||
static __host__ __device__ void unpack(T& value, int32_t& index, TypeCmp cmp)
|
||||
{
|
||||
// Since “65535-idx” is always smaller than 65536 and positive, we can directly use it as the lower 16 bits
|
||||
index = maxIdx - static_cast<int32_t>((cmp & 0xFFFF));
|
||||
|
||||
auto compactTmp = cmp >> moveBits;
|
||||
auto valueBits
|
||||
= cub::Traits<T>::TwiddleOut(reinterpret_cast<typename cub::Traits<T>::UnsignedBits&>(compactTmp));
|
||||
value = reinterpret_cast<T&>(valueBits);
|
||||
}
|
||||
|
||||
__host__ __device__ TopKRedType() = default;
|
||||
|
||||
__host__ __device__ TopKRedType(T val, int32_t idx)
|
||||
: compValIdx(makeCmpVal(val, idx))
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ operator TypeCmp() const noexcept
|
||||
{
|
||||
return compValIdx;
|
||||
}
|
||||
|
||||
__device__ inline TypeCmp reduce(cg::thread_block_tile<WARP_SIZE> const& warp)
|
||||
{
|
||||
if constexpr (!TLLM_GEN_HAS_FAST_REDUX || sizeof(TypeCmp) == 8)
|
||||
{
|
||||
return cg::reduce(warp, compValIdx, cg::greater<TypeCmp>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
TypeCmp result;
|
||||
asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compValIdx));
|
||||
return result;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int K_, bool Enable_>
|
||||
struct TopKIdx
|
||||
{
|
||||
// by default, empty
|
||||
};
|
||||
|
||||
template <int K_>
|
||||
struct TopKIdx<K_, true>
|
||||
{
|
||||
static constexpr int K = K_;
|
||||
int32_t val[K];
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define TOPK_SWAP(I, J) \
|
||||
{ \
|
||||
auto pairMin = min(topK[I].compValIdx, topK[J].compValIdx); \
|
||||
auto pairMax = max(topK[I].compValIdx, topK[J].compValIdx); \
|
||||
topK[I].compValIdx = pairMax; \
|
||||
topK[J].compValIdx = pairMin; \
|
||||
}
|
||||
|
||||
template <int N, typename RedType>
|
||||
struct Sort;
|
||||
|
||||
template <typename RedType>
|
||||
struct Sort<1, RedType>
|
||||
{
|
||||
static __device__ void run(RedType* topK) {}
|
||||
};
|
||||
|
||||
template <typename RedType>
|
||||
struct Sort<2, RedType>
|
||||
{
|
||||
static __device__ void run(RedType* topK)
|
||||
{
|
||||
TOPK_SWAP(0, 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename RedType>
|
||||
struct Sort<3, RedType>
|
||||
{
|
||||
static __device__ void run(RedType* topK)
|
||||
{
|
||||
TOPK_SWAP(0, 1);
|
||||
TOPK_SWAP(1, 2);
|
||||
TOPK_SWAP(0, 1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename RedType>
|
||||
struct Sort<4, RedType>
|
||||
{
|
||||
static __device__ void run(RedType* topK)
|
||||
{
|
||||
TOPK_SWAP(0, 2);
|
||||
TOPK_SWAP(1, 3);
|
||||
TOPK_SWAP(0, 1);
|
||||
TOPK_SWAP(2, 3);
|
||||
TOPK_SWAP(1, 2);
|
||||
}
|
||||
};
|
||||
|
||||
template <int K, typename Type, int N, bool IsSorted = false>
|
||||
__device__ void reduceTopK(cg::thread_block_tile<WARP_SIZE> const& warp, Type (&out)[K], int32_t (&outIdx)[K],
|
||||
Type (&value)[N], int32_t (&idx)[N], Type minValue)
|
||||
{
|
||||
static_assert(K > 0, "Top K must have K > 0");
|
||||
static_assert(K < WARP_SIZE, "Top K must have K < WARP_SIZE");
|
||||
static_assert(N > 0, "Top K must have N > 0");
|
||||
static_assert(N < 5, "Only support candidates number less than or equal to 128");
|
||||
using RedType = TopKRedType<Type>;
|
||||
RedType topK[N];
|
||||
#pragma unroll
|
||||
for (int nn = 0; nn < N; ++nn)
|
||||
{
|
||||
topK[nn] = RedType{value[nn], idx[nn]};
|
||||
}
|
||||
|
||||
if constexpr (!IsSorted)
|
||||
{
|
||||
Sort<N, RedType>::run(topK);
|
||||
}
|
||||
typename RedType::TypeCmp packedMax{};
|
||||
#pragma unroll
|
||||
for (int kk = 0; kk < K; ++kk)
|
||||
{
|
||||
bool update = kk > 0 && packedMax == topK[0].compValIdx;
|
||||
#pragma unroll
|
||||
for (int nn = 0; nn < N; ++nn)
|
||||
{
|
||||
topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn];
|
||||
}
|
||||
// get the next largest value
|
||||
packedMax = topK[0].reduce(warp);
|
||||
RedType::unpack(out[kk], outIdx[kk], packedMax);
|
||||
}
|
||||
};
|
||||
|
||||
#undef TOPK_SWAP
|
||||
|
||||
} // end of namespace reduce_topk
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T>
|
||||
__device__ T calcSoftmax(cg::thread_block_tile<WARP_SIZE> const& warp, T score, int32_t laneIdx, int32_t NumTopExperts)
|
||||
{
|
||||
T maxScore = T{-INFINITY};
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
maxScore = score >= maxScore ? score : maxScore;
|
||||
}
|
||||
maxScore = cg::reduce(warp, maxScore, cg::greater<T>());
|
||||
|
||||
float sumScore = float{0.f};
|
||||
float newScore;
|
||||
// Get the summation of scores for each token
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
newScore = static_cast<float>(score) - static_cast<float>(maxScore);
|
||||
newScore = static_cast<float>(exp(newScore));
|
||||
sumScore += newScore;
|
||||
}
|
||||
sumScore = cg::reduce(warp, sumScore, cg::plus<float>());
|
||||
|
||||
if (laneIdx < NumTopExperts)
|
||||
{
|
||||
score = static_cast<T>(newScore / sumScore);
|
||||
}
|
||||
|
||||
return score;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename InputT, typename OutputT, typename IdxT, int MaxNumExperts, int MaxNumTopExperts>
|
||||
__global__ void renormMoeRoutingKernel(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices,
|
||||
int32_t const numTokens, int32_t const numExperts, int32_t const topK)
|
||||
{
|
||||
|
||||
uint32_t const blockRank = blockIdx.x;
|
||||
uint32_t const tIdx = BLOCK_SIZE * blockRank + threadIdx.x;
|
||||
uint32_t const warpIdx = tIdx / WARP_SIZE;
|
||||
uint32_t const laneIdx = tIdx % WARP_SIZE;
|
||||
uint32_t const warpNum = gridDim.x * WARPS_PER_BLOCK;
|
||||
auto block = cg::this_thread_block();
|
||||
auto warp = cg::tiled_partition<WARP_SIZE>(block);
|
||||
|
||||
InputT minScore = InputT{-INFINITY};
|
||||
for (uint32_t tokenId = warpIdx; tokenId < numTokens; tokenId += warpNum)
|
||||
{
|
||||
auto scoreOffset = tokenId * numExperts;
|
||||
auto outputOffset = tokenId * topK;
|
||||
InputT inputScore[MaxNumExperts / WARP_SIZE];
|
||||
IdxT inputIndex[MaxNumExperts / WARP_SIZE];
|
||||
|
||||
InputT warpTopKScore[MaxNumTopExperts];
|
||||
IdxT warpTopKExpertIdx[MaxNumTopExperts];
|
||||
|
||||
// Load scores and indices for this warp
|
||||
for (uint32_t i = 0; i < MaxNumExperts / WARP_SIZE; ++i)
|
||||
{
|
||||
auto expertIdx = i * WARP_SIZE + laneIdx;
|
||||
inputScore[i]
|
||||
= expertIdx < numExperts ? static_cast<InputT>(routerLogits[scoreOffset + expertIdx]) : minScore;
|
||||
inputIndex[i] = expertIdx;
|
||||
}
|
||||
|
||||
// Reduce topK scores and indices for this warp
|
||||
reduce_topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);
|
||||
|
||||
// Perform softmax on topK scores
|
||||
auto score = calcSoftmax(warp,
|
||||
laneIdx < topK ? static_cast<float>(warpTopKScore[laneIdx]) : static_cast<float>(minScore), laneIdx, topK);
|
||||
if (laneIdx < topK)
|
||||
{
|
||||
topkValues[outputOffset + laneIdx] = static_cast<OutputT>(score);
|
||||
topkIndices[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
|
||||
}
|
||||
} // end for tokenId
|
||||
}
|
||||
|
||||
int nextPowerOfTwo(int num)
|
||||
{
|
||||
if (num <= 0)
|
||||
{
|
||||
return 1; // Handle invalid input
|
||||
}
|
||||
int power = 1;
|
||||
while (power < num)
|
||||
{
|
||||
// Check for overflow before shifting
|
||||
if (power > INT_MAX / 2)
|
||||
{
|
||||
return power;
|
||||
}
|
||||
power <<= 1;
|
||||
}
|
||||
return power;
|
||||
}
|
||||
|
||||
#define CASE(MAX_NUM_EXPERTS) \
|
||||
case MAX_NUM_EXPERTS: \
|
||||
switch (maxNumTopExperts) \
|
||||
{ \
|
||||
case 1: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 1>; break; \
|
||||
case 2: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 2>; break; \
|
||||
case 4: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 4>; break; \
|
||||
case 8: kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, MAX_NUM_EXPERTS, 8>; break; \
|
||||
default: kernelInstance = nullptr; break; \
|
||||
} \
|
||||
break;
|
||||
|
||||
template <typename InputT, typename OutputT, typename IdxT>
|
||||
void invokeRenormMoeRouting(InputT* routerLogits, OutputT* topkValues, IdxT* topkIndices, int64_t const numTokens,
|
||||
int64_t const numExperts, int64_t const topK, cudaStream_t const stream)
|
||||
{
|
||||
|
||||
const uint32_t maxNumBlocks = 1024;
|
||||
const uint32_t numBlocks = std::min(static_cast<uint32_t>((numTokens - 1) / WARPS_PER_BLOCK + 1), maxNumBlocks);
|
||||
|
||||
uint32_t maxNumExperts = nextPowerOfTwo(numExperts) < 32 ? 32 : nextPowerOfTwo(numExperts);
|
||||
uint32_t maxNumTopExperts = nextPowerOfTwo(topK);
|
||||
|
||||
auto* kernelInstance = &renormMoeRoutingKernel<InputT, OutputT, IdxT, 128, 8>;
|
||||
|
||||
switch (maxNumExperts)
|
||||
{
|
||||
CASE(32)
|
||||
CASE(64)
|
||||
CASE(96)
|
||||
CASE(128)
|
||||
default: kernelInstance = nullptr; break;
|
||||
}
|
||||
|
||||
if (kernelInstance == nullptr)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(kernelInstance != nullptr, "Can not find corresponding kernel instance.");
|
||||
}
|
||||
|
||||
dim3 renormMoeRoutingGridDim(numBlocks);
|
||||
dim3 renormMoeRoutingBlockDim(BLOCK_SIZE);
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = renormMoeRoutingGridDim;
|
||||
config.blockDim = renormMoeRoutingBlockDim;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernelInstance, routerLogits, topkValues, topkIndices, static_cast<int32_t>(numTokens),
|
||||
static_cast<int32_t>(numExperts), static_cast<int32_t>(topK));
|
||||
sync_check_cuda_error(stream);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_RENORM_MOE_ROUTING(InputT, OutputT, IdxT) \
|
||||
template void invokeRenormMoeRouting<InputT, OutputT, IdxT>(InputT * routerLogits, OutputT * topkValues, \
|
||||
IdxT * topkIndices, int64_t const numTokens, int64_t const numExperts, int64_t const topK, \
|
||||
cudaStream_t const stream);
|
||||
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(float, float, int32_t);
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(half, float, int32_t);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_RENORM_MOE_ROUTING(__nv_bfloat16, float, int32_t);
|
||||
#endif
|
||||
|
||||
} // namespace tensorrt_llm::kernels
|
||||
@ -22,11 +22,17 @@
|
||||
*/
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
#include "moeTopKFuncs.cuh"
|
||||
#include "topkLastDim.h"
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cuda/atomic>
|
||||
#include <cuda/std/limits>
|
||||
#include <limits>
|
||||
#include <thrust/iterator/counting_iterator.h>
|
||||
#include <thrust/iterator/transform_iterator.h>
|
||||
#include <type_traits>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
@ -203,12 +209,12 @@ __host__ __device__ IdxT calc_buf_len(IdxT len)
|
||||
* @param len the number of elements to read
|
||||
* @param f the lambda taking two arguments (T x, IdxT idx)
|
||||
*/
|
||||
template <typename T, typename idxT, typename Func>
|
||||
__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, idxT len, Func f)
|
||||
template <typename T, typename IdxT, typename Func>
|
||||
__device__ void vectorized_process(size_t thread_rank, size_t num_threads, T const* in, IdxT len, Func f)
|
||||
{
|
||||
if constexpr (sizeof(T) >= sizeof(WideT))
|
||||
{
|
||||
for (idxT i = thread_rank; i < len; i += num_threads)
|
||||
for (IdxT i = thread_rank; i < len; i += num_threads)
|
||||
{
|
||||
f(in[i], i);
|
||||
}
|
||||
@ -233,12 +239,12 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con
|
||||
skip_cnt = len;
|
||||
}
|
||||
WideT const* in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
|
||||
const idxT len_cast = (len - skip_cnt) / items_per_scalar;
|
||||
const IdxT len_cast = (len - skip_cnt) / items_per_scalar;
|
||||
|
||||
for (idxT i = thread_rank; i < len_cast; i += num_threads)
|
||||
for (IdxT i = thread_rank; i < len_cast; i += num_threads)
|
||||
{
|
||||
wide.scalar = in_cast[i];
|
||||
const idxT real_i = skip_cnt + i * items_per_scalar;
|
||||
const IdxT real_i = skip_cnt + i * items_per_scalar;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < items_per_scalar; ++j)
|
||||
{
|
||||
@ -258,7 +264,7 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con
|
||||
// and so
|
||||
// len - (skip_cnt + len_cast * items_per_scalar) < items_per_scalar <= WARP_SIZE
|
||||
// no need to use loop
|
||||
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
|
||||
const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + thread_rank;
|
||||
if (remain_i < len)
|
||||
{
|
||||
f(in[remain_i], remain_i);
|
||||
@ -267,14 +273,14 @@ __device__ void vectorized_process(size_t thread_rank, size_t num_threads, T con
|
||||
}
|
||||
|
||||
// sync_width should >= WARP_SIZE
|
||||
template <typename T, typename idxT, typename Func>
|
||||
__device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width)
|
||||
template <typename T, typename IdxT, typename Func>
|
||||
__device__ void vectorized_process(T const* in, IdxT len, Func f, int sync_width)
|
||||
{
|
||||
const idxT stride = blockDim.x * gridDim.x;
|
||||
const idxT tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const IdxT stride = blockDim.x * gridDim.x;
|
||||
const IdxT tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if constexpr (sizeof(T) >= sizeof(WideT))
|
||||
{
|
||||
for (idxT i = tid; i < len; i += stride)
|
||||
for (IdxT i = tid; i < len; i += stride)
|
||||
{
|
||||
f(in[i], i, true);
|
||||
}
|
||||
@ -298,17 +304,17 @@ __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width
|
||||
skip_cnt = len;
|
||||
}
|
||||
WideT const* in_cast = reinterpret_cast<decltype(in_cast)>(in + skip_cnt);
|
||||
const idxT len_cast = (len - skip_cnt) / items_per_scalar;
|
||||
const IdxT len_cast = (len - skip_cnt) / items_per_scalar;
|
||||
|
||||
const idxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width;
|
||||
for (idxT i = tid; i < len_cast_for_sync; i += stride)
|
||||
const IdxT len_cast_for_sync = ((len_cast - 1) / sync_width + 1) * sync_width;
|
||||
for (IdxT i = tid; i < len_cast_for_sync; i += stride)
|
||||
{
|
||||
bool valid = i < len_cast;
|
||||
if (valid)
|
||||
{
|
||||
wide.scalar = in_cast[i];
|
||||
}
|
||||
const idxT real_i = skip_cnt + i * items_per_scalar;
|
||||
const IdxT real_i = skip_cnt + i * items_per_scalar;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < items_per_scalar; ++j)
|
||||
{
|
||||
@ -325,7 +331,7 @@ __device__ void vectorized_process(T const* in, idxT len, Func f, int sync_width
|
||||
T value = valid ? in[tid] : T();
|
||||
f(value, tid, valid);
|
||||
|
||||
const idxT remain_i = skip_cnt + len_cast * items_per_scalar + tid;
|
||||
const IdxT remain_i = skip_cnt + len_cast * items_per_scalar + tid;
|
||||
valid = remain_i < len;
|
||||
value = valid ? in[remain_i] : T();
|
||||
f(value, remain_i, valid);
|
||||
@ -1166,6 +1172,77 @@ __global__ void radix_topk_one_block_kernel(T const* in, IdxT const* in_idx, con
|
||||
} // namespace air_topk_stable
|
||||
|
||||
//}
|
||||
namespace moe_topk
{
namespace cg = cooperative_groups;
static constexpr int kBLOCK_SIZE = 1024;
static constexpr int kWARP_SIZE = 32;
static constexpr int kWARPS_PER_BLOCK = kBLOCK_SIZE / kWARP_SIZE;

template <typename T>
__device__ T negativeInfinity()
{
    return -INFINITY;
}

template <>
__device__ half negativeInfinity<half>()
{
    return -CUDART_INF_FP16;
}

template <>
__device__ __nv_bfloat16 negativeInfinity<__nv_bfloat16>()
{
    return -CUDART_INF_BF16;
}

/**************** TopK kernel for candidate number <= 128 and K <= 8 ****************/
template <typename InputT, typename OutputT, typename IdxT, int MaxLen, int MaxTopK>
__global__ void moe_topk_kernel(
    InputT const* in, OutputT* out, IdxT* outIdx, int32_t const batchSize, int32_t const len, int32_t const topK)
{

    uint32_t const blockRank = blockIdx.x;
    uint32_t const tIdx = kBLOCK_SIZE * blockRank + threadIdx.x;
    uint32_t const warpIdx = tIdx / kWARP_SIZE;
    uint32_t const laneIdx = tIdx % kWARP_SIZE;
    uint32_t const warpNum = gridDim.x * kWARPS_PER_BLOCK;
    auto block = cg::this_thread_block();
    auto warp = cg::tiled_partition<kWARP_SIZE>(block);

    InputT minScore = negativeInfinity<InputT>();

    for (uint32_t tokenId = warpIdx; tokenId < batchSize; tokenId += warpNum)
    {
        auto scoreOffset = tokenId * len;
        auto outputOffset = tokenId * topK;
        InputT inputScore[MaxLen / kWARP_SIZE];
        IdxT inputIndex[MaxLen / kWARP_SIZE];

        InputT warpTopKScore[MaxTopK];
        IdxT warpTopKExpertIdx[MaxTopK];

        // Load scores and indices for this warp
        for (uint32_t i = 0; i < MaxLen / kWARP_SIZE; ++i)
        {
            auto expertIdx = i * kWARP_SIZE + laneIdx;
            inputScore[i] = expertIdx < len ? static_cast<InputT>(in[scoreOffset + expertIdx]) : minScore;
            inputIndex[i] = expertIdx;
        }

        // Reduce topK scores and indices for this warp
        tensorrt_llm::kernels::reduce_topk::reduceTopK(
            warp, warpTopKScore, warpTopKExpertIdx, inputScore, inputIndex, minScore);

        if (laneIdx < topK)
        {
            out[outputOffset + laneIdx] = static_cast<OutputT>(warpTopKScore[laneIdx]);
            outIdx[outputOffset + laneIdx] = warpTopKExpertIdx[laneIdx];
        }
    } // end for tokenId
}
} // namespace moe_topk

/***************Runtime API****************/
|
||||
|
||||
@ -1223,9 +1300,11 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx
|
||||
IdxT* sort_in_idx = nullptr;
|
||||
|
||||
air_topk_stable::ComputeOffset<IdxT> computeoffset(k);
|
||||
|
||||
thrust::counting_iterator<IdxT> counting_iter(0);
|
||||
thrust::transform_iterator<air_topk_stable::ComputeOffset<IdxT>, thrust::counting_iterator<IdxT>> transform_iter(
|
||||
counting_iter, computeoffset);
|
||||
|
||||
cub::DeviceSegmentedSort::SortPairs(NULL, temp_storage_bytes, out_idx, out_idx, out, out, k * batch_size,
|
||||
batch_size, transform_iter, transform_iter + 1, stream);
|
||||
if (sorted)
|
||||
@ -1277,8 +1356,8 @@ void standalone_stable_radix_topk_(void* buf, size_t& buf_size, T const* in, Idx
|
||||
sort_in = static_cast<decltype(sort_in)>(aligned_pointers[9]);
|
||||
sort_in_idx = static_cast<decltype(sort_in_idx)>(aligned_pointers[10]);
|
||||
}
|
||||
cudaMemsetAsync(
|
||||
buf, 0, static_cast<char*>(aligned_pointers[2]) - static_cast<char*>(aligned_pointers[0]), stream);
|
||||
cudaMemsetAsync(aligned_pointers[0], 0,
|
||||
static_cast<char*>(aligned_pointers[2]) - static_cast<char*>(aligned_pointers[0]), stream);
|
||||
}
|
||||
|
||||
T const* in_buf = nullptr;
|
||||
@ -1423,36 +1502,120 @@ void standalone_stable_radix_topk_one_block_(void* buf, size_t& buf_size, T cons
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename idxT, bool sorted = false>
|
||||
void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, idxT len, idxT k, T* out,
|
||||
idxT* out_idx, bool greater, cudaStream_t stream = 0)
|
||||
template <typename T, typename IdxT, bool sorted = false>
|
||||
void standalone_stable_radix_11bits(void* buf, size_t& buf_size, T const* in, int batch_size, IdxT len, IdxT k, T* out,
|
||||
IdxT* out_idx, bool greater, cudaStream_t stream = 0)
|
||||
{
|
||||
constexpr int items_per_thread = 32;
|
||||
constexpr int block_dim = 512;
|
||||
constexpr bool fused_last_filter = false;
|
||||
if (len <= block_dim * items_per_thread)
|
||||
{
|
||||
standalone_stable_radix_topk_one_block_<T, idxT, 11, block_dim>(
|
||||
buf, buf_size, in, static_cast<idxT*>(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted);
|
||||
standalone_stable_radix_topk_one_block_<T, IdxT, 11, block_dim>(
|
||||
buf, buf_size, in, static_cast<IdxT*>(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted);
|
||||
}
|
||||
else
|
||||
{
|
||||
int sm_cnt = tensorrt_llm::common::getMultiProcessorCount();
|
||||
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, idxT, 11, block_dim>(batch_size, len, sm_cnt);
|
||||
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, IdxT, 11, block_dim>(batch_size, len, sm_cnt);
|
||||
|
||||
if (grid_dim == 1)
|
||||
{
|
||||
standalone_stable_radix_topk_one_block_<T, idxT, 11, block_dim>(buf, buf_size, in,
|
||||
static_cast<idxT*>(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted);
|
||||
standalone_stable_radix_topk_one_block_<T, IdxT, 11, block_dim>(buf, buf_size, in,
|
||||
static_cast<IdxT*>(nullptr), batch_size, len, k, out, out_idx, !greater, stream, sorted);
|
||||
}
|
||||
else
|
||||
{
|
||||
standalone_stable_radix_topk_<T, idxT, 11, block_dim>(buf, buf_size, in, static_cast<idxT*>(nullptr),
|
||||
standalone_stable_radix_topk_<T, IdxT, 11, block_dim>(buf, buf_size, in, static_cast<IdxT*>(nullptr),
|
||||
batch_size, len, k, out, out_idx, !greater, fused_last_filter, grid_dim, stream, sorted);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int nextPowerOfTwo(int num)
{
    if (num <= 0)
    {
        return 1; // Handle invalid input
    }
    int power = 1;
    while (power < num)
    {
        // Check for overflow before shifting
        if (power > INT_MAX / 2)
        {
            return power;
        }
        power <<= 1;
    }
    return power;
}
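As a worked illustration of the rounding that drives the dispatch below (the values are examples only): max_len clamps to at least 32 and rounds the candidate count up to a power of two, and moe_topk rounds k up the same way.

// len = 20,  k = 2  ->  max_len = 32,  moe_topk = 2  ->  moe_topk_kernel<..., 32, 2>
// len = 60,  k = 3  ->  max_len = 64,  moe_topk = 4  ->  moe_topk_kernel<..., 64, 4>
// len = 100, k = 8  ->  max_len = 128, moe_topk = 8  ->  moe_topk_kernel<..., 128, 8>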

template <typename T, typename IdxT>
void moe_reduce_topk(
    T const* in, int batch_size, IdxT len, IdxT k, T* out, IdxT* out_idx, bool greater, cudaStream_t stream = 0)
{
    using InputT = T;
    using OutputT = T;
    const uint32_t max_num_blocks = 1024;
    const uint32_t num_blocks
        = std::min(static_cast<uint32_t>((batch_size - 1) / moe_topk::kWARPS_PER_BLOCK + 1), max_num_blocks);

    uint32_t max_len = nextPowerOfTwo(len) < 32 ? 32 : nextPowerOfTwo(len);
    uint32_t moe_topk = nextPowerOfTwo(k);

    auto* kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 8>;

    switch (max_len)
    {
    case 32:
        switch (moe_topk)
        {
        case 1: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 32, 1>; break;
        case 2: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 32, 2>; break;
        case 4: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 32, 4>; break;
        case 8: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 32, 8>; break;
        default: kernel_instance = nullptr; break;
        }
        break;
    case 64:
        switch (moe_topk)
        {
        case 1: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 64, 1>; break;
        case 2: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 64, 2>; break;
        case 4: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 64, 4>; break;
        case 8: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 64, 8>; break;
        default: kernel_instance = nullptr; break;
        }
        break;
    case 96:
        switch (moe_topk)
        {
        case 1: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 96, 1>; break;
        case 2: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 96, 2>; break;
        case 4: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 96, 4>; break;
        case 8: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 96, 8>; break;
        default: kernel_instance = nullptr; break;
        }
        break;
    case 128:
        switch (moe_topk)
        {
        case 1: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 1>; break;
        case 2: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 2>; break;
        case 4: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 4>; break;
        case 8: kernel_instance = &moe_topk::moe_topk_kernel<InputT, OutputT, IdxT, 128, 8>; break;
        default: kernel_instance = nullptr; break;
        }
        break;
    default: kernel_instance = nullptr; break;
    }

    dim3 moe_topk_grid_dim(num_blocks);
    dim3 moe_topk_block_dim(moe_topk::kBLOCK_SIZE);

    kernel_instance<<<moe_topk_grid_dim, moe_topk_block_dim, 0, stream>>>(in, out, out_idx, batch_size, len, k);
}
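A minimal illustrative call site, assuming the device buffers already exist (the helper below is hypothetical and not part of this file):

// Select the top-8 of 128 router logits per row. As written above, the kernel always keeps the
// largest k values; the `greater` flag is accepted only for signature parity with the radix path.
void runMoeTopKExample(float const* d_logits, float* d_vals, int32_t* d_idx, int batch_size, cudaStream_t stream)
{
    moe_reduce_topk<float, int32_t>(d_logits, batch_size, /*len=*/128, /*k=*/8, d_vals, d_idx,
        /*greater=*/true, stream);
}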
#endif
|
||||
|
||||
///////////////
|
||||
@ -1461,22 +1624,22 @@ template <typename T>
|
||||
size_t invokeComputeTopkLastDimWorkspaceSize(
|
||||
SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest)
|
||||
{
|
||||
using idxT = SizeType32;
|
||||
using IdxT = SizeType32;
|
||||
|
||||
size_t buf_size = 0;
|
||||
void* workspace = nullptr;
|
||||
T const* in = nullptr;
|
||||
T* out_val = nullptr;
|
||||
idxT* out_idx = nullptr;
|
||||
IdxT* out_idx = nullptr;
|
||||
|
||||
constexpr int block_dim = 512;
|
||||
constexpr bool fused_last_filter = false;
|
||||
constexpr bool sorted = true;
|
||||
|
||||
int sm_cnt = tensorrt_llm::common::getMultiProcessorCount();
|
||||
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, idxT, 11, block_dim>(batchSize, inputLength, sm_cnt);
|
||||
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, IdxT, 11, block_dim>(batchSize, inputLength, sm_cnt);
|
||||
|
||||
standalone_stable_radix_topk_<T, idxT, 11, block_dim>(workspace, buf_size, in, static_cast<idxT*>(nullptr),
|
||||
standalone_stable_radix_topk_<T, IdxT, 11, block_dim>(workspace, buf_size, in, static_cast<IdxT*>(nullptr),
|
||||
batchSize, inputLength, k, out_val, out_idx, !is_largest, fused_last_filter, grid_dim, 0, sorted);
|
||||
return buf_size;
|
||||
}
|
||||
@ -1506,8 +1669,17 @@ void invokeTopkLastDim(SizeType32 batchSize, SizeType32 inputLength, SizeType32
|
||||
T const* in = reinterpret_cast<T const*>(input);
|
||||
T* out_val_ = reinterpret_cast<T*>(out_val);
|
||||
SizeType32* out_idx_ = reinterpret_cast<SizeType32*>(out_idx);
|
||||
standalone_stable_radix_11bits<T, SizeType32, true>(
|
||||
workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream);
|
||||
if (inputLength <= 128 && k <= 8 && is_largest == true)
|
||||
{
|
||||
// This method does not require a buffer, but since the implementation may vary in different cases,
|
||||
// we still allocate the buffer in case AIR TopK is used instead.
|
||||
moe_reduce_topk(in, batchSize, inputLength, k, out_val_, out_idx_, !is_largest, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
standalone_stable_radix_11bits<T, SizeType32, true>(
|
||||
workspace, buf_size, in, batchSize, inputLength, k, out_val_, out_idx_, is_largest, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TOPK_LastDim_DATA_TYPE(T) \
|
||||
|
||||
@ -378,7 +378,7 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx<BaseType>
|
||||
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
|
||||
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
|
||||
// TODO: this is not sufficient to ensure visibility in the next kernel!
|
||||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
@ -757,15 +757,13 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
// Trigger secondary kernel.
|
||||
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
|
||||
// dependency sync.
|
||||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||||
// Trigger secondary kernel.
|
||||
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
|
||||
// dependency sync.
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif
|
||||
#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
}
|
||||
|
||||
|
||||
@ -227,13 +227,11 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||||
// we can trigger the next kernel at this point
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// at this point, all values for offsets are ready, except the final offsets
|
||||
|
||||
@ -199,13 +199,11 @@ __global__ void __launch_bounds__(NumThreadsSingleBlock) routingIndicesBlockKern
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
|
||||
// we can trigger the next kernel at this point
|
||||
if constexpr (KernelParams::UsePdl)
|
||||
{
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++)
|
||||
|
||||
@ -43,7 +43,7 @@ target_link_libraries(
|
||||
${Python3_LIBRARIES}
|
||||
${TORCH_LIBRARIES}
|
||||
torch_python
|
||||
CUDA::cuda_driver
|
||||
${CUDA_DRV_LIB}
|
||||
${CUDA_NVML_LIB}
|
||||
th_common)
|
||||
target_compile_definitions(
|
||||
|
||||
@ -285,5 +285,35 @@ struct type_caster<std::vector<std::reference_wrapper<T const>>>
|
||||
return make_caster<std::vector<T>>::from_cpp(result, policy, cleanup);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct type_caster<torch::ScalarType>
|
||||
{
|
||||
NB_TYPE_CASTER(torch::ScalarType, const_name("torch.dtype"));
|
||||
|
||||
bool from_python(handle src, uint8_t flags, cleanup_list* cleanup) noexcept
|
||||
{
|
||||
std::string dtype_name = nb::cast<std::string>(nb::str(src));
|
||||
if (dtype_name.substr(0, 6) == "torch.")
|
||||
{
|
||||
dtype_name = dtype_name.substr(6);
|
||||
}
|
||||
|
||||
auto const& dtype_map = c10::getStringToDtypeMap();
|
||||
auto it = dtype_map.find(dtype_name);
|
||||
if (it != dtype_map.end())
|
||||
{
|
||||
value = it->second;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static handle from_cpp(torch::ScalarType src, rv_policy policy, cleanup_list* cleanup)
|
||||
{
|
||||
throw std::runtime_error("from_cpp for torch::ScalarType is not implemented");
|
||||
}
|
||||
};
|
||||
} // namespace detail
|
||||
} // namespace NB_NAMESPACE
|
||||
|
||||
@ -240,7 +240,8 @@ void initBindings(nb::module_& m)
|
||||
nb::class_<tle::KVCacheEvent>(executor_kv_cache, "KVCacheEvent")
|
||||
.def_ro("event_id", &tle::KVCacheEvent::eventId)
|
||||
.def_ro("data", &tle::KVCacheEvent::data)
|
||||
.def_ro("window_size", &tle::KVCacheEvent::windowSize);
|
||||
.def_ro("window_size", &tle::KVCacheEvent::windowSize)
|
||||
.def_ro("attention_dp_rank", &tle::KVCacheEvent::attentionDpRank);
|
||||
|
||||
nb::class_<tle::KVCacheEventManager>(executor_kv_cache, "KVCacheEventManager")
|
||||
.def(
|
||||
|
||||
@ -27,6 +27,7 @@
|
||||
|
||||
#include <nanobind/nanobind.h>
|
||||
#include <nanobind/stl/chrono.h>
|
||||
#include <nanobind/stl/function.h>
|
||||
#include <nanobind/stl/list.h>
|
||||
#include <nanobind/stl/optional.h>
|
||||
#include <nanobind/stl/shared_ptr.h>
|
||||
|
||||
@ -279,7 +279,7 @@ void initBindings(nb::module_& m)
|
||||
.def(nb::init<tr::GptDecoderBatched::CudaStreamPtr>(), nb::arg("stream"))
|
||||
.def("setup", &tr::GptDecoderBatched::setup, nb::arg("mode"), nb::arg("max_num_sequences"),
|
||||
nb::arg("max_beam_width"), nb::arg("dtype"), nb::arg("model_config"), nb::arg("world_config"))
|
||||
.def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("output"), nb::arg("input"))
|
||||
.def("forward_async", &tr::GptDecoderBatched::forwardAsync, nb::arg("decoder_state"), nb::arg("input"))
|
||||
.def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, nb::rv_policy::reference)
|
||||
.def("finalize", &tr::GptDecoderBatched::finalize, nb::arg("decoder_state"), nb::arg("batch_idx"),
|
||||
nb::arg("sampling_config"), nb::arg("streaming"))
|
||||
|
||||
@ -946,8 +946,8 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
std::optional<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> gemm2;
|
||||
if (common::getEnvForceDeterministicMOE())
|
||||
{
|
||||
gemm1 = mMOERunner->getTactics()[0];
|
||||
gemm2 = mMOERunner->getTactics()[0];
|
||||
gemm1 = mMOERunner->getTactics(MoeGemmId::GEMM_1)[0];
|
||||
gemm2 = mMOERunner->getTactics(MoeGemmId::GEMM_2)[0];
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -1278,7 +1278,7 @@ void MixtureOfExpertsGemmProfiler::runTactic(int m, int n, int k, MixtureOfExper
|
||||
auto MixtureOfExpertsGemmProfiler::getTactics(int m, int n, int k) const -> std::vector<Config>
|
||||
{
|
||||
assert(mRunner);
|
||||
return mRunner->mMOERunner->getTactics();
|
||||
return mRunner->mMOERunner->getTactics(backend.mGemmToProfile);
|
||||
}
|
||||
|
||||
void MixtureOfExpertsGemmProfiler::initTmpData(
|
||||
|
||||
@ -43,6 +43,7 @@ namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
|
||||
using MoeMinLatencyParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MoeMinLatencyParams;
|
||||
using MOEParallelismConfig = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::MOEParallelismConfig;
|
||||
using QuantParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::QuantParams;
|
||||
using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId;
|
||||
using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
|
||||
using ActivationParams = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams;
|
||||
using TmaWarpSpecializedGroupedGemmInput = CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput;
|
||||
|
||||
@ -44,7 +44,7 @@ target_link_libraries(
|
||||
${Python3_LIBRARIES}
|
||||
${TORCH_LIBRARIES}
|
||||
torch_python
|
||||
CUDA::cuda_driver
|
||||
${CUDA_DRV_LIB}
|
||||
${CUDA_NVML_LIB}
|
||||
th_common)
|
||||
target_compile_definitions(
|
||||
|
||||
@ -131,6 +131,7 @@ void DecoderState::setupSpeculativeDecodingBuffers(
|
||||
|
||||
mSpeculativeDecodingMode = speculativeDecodingMode;
|
||||
|
||||
auto constexpr nvTokenIdType = TRTDataType<TokenIdType>::value;
|
||||
auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
|
||||
|
||||
auto& dInput = mJointDecodingInput;
|
||||
@ -179,6 +180,7 @@ void DecoderState::setupSpeculativeDecodingBuffers(
|
||||
DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs;
|
||||
|
||||
externalDraftTokensInputs.draftLogits = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
|
||||
externalDraftTokensInputs.draftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, dtype);
|
||||
externalDraftTokensInputs.draftProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
|
||||
externalDraftTokensInputs.targetProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
|
||||
externalDraftTokensInputs.numDraftTokens = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
@ -187,8 +189,8 @@ void DecoderState::setupSpeculativeDecodingBuffers(
|
||||
= bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
|
||||
externalDraftTokensInputs.useDraftLogitsHost
|
||||
= bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<bool>::value);
|
||||
externalDraftTokensInputs.draftTokenIds
|
||||
= bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
|
||||
externalDraftTokensInputs.draftTokenIds = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
externalDraftTokensInputs.draftTokenIdsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvTokenIdType);
|
||||
|
||||
dInput->externalDraftTokensInputs = externalDraftTokensInputs;
|
||||
}
|
||||
@ -366,10 +368,16 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con
|
||||
{mMaxNumSequences, mMaxDecodingEngineTokens, mMaxBeamWidth, static_cast<SizeType32>(vocabSizePadded)});
|
||||
dInput.externalDraftTokensInputs->draftProbs->reshape(probsShape);
|
||||
dInput.externalDraftTokensInputs->targetProbs->reshape(probsShape);
|
||||
dInput.externalDraftTokensInputs->draftLogits->reshape(
|
||||
ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens, static_cast<SizeType32>(vocabSizePadded)}));
|
||||
dInput.externalDraftTokensInputs->draftTokenIds->reshape(
|
||||
ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens}));
|
||||
|
||||
auto const logitsShape = ITensor::makeShape(
|
||||
{mMaxNumSequences, mMaxDecodingEngineTokens, static_cast<SizeType32>(vocabSizePadded)});
|
||||
dInput.externalDraftTokensInputs->draftLogits->reshape(logitsShape);
|
||||
dInput.externalDraftTokensInputs->draftLogitsHost->reshape(logitsShape);
|
||||
|
||||
auto const tokenIdsShape = ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens});
|
||||
dInput.externalDraftTokensInputs->draftTokenIds->reshape(tokenIdsShape);
|
||||
dInput.externalDraftTokensInputs->draftTokenIdsHost->reshape(tokenIdsShape);
|
||||
|
||||
dInput.externalDraftTokensInputs->numDraftTokens->reshape(maxNumSequencesShape);
|
||||
dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(maxNumSequencesShape);
|
||||
dInput.externalDraftTokensInputs->useDraftLogits->reshape(maxNumSequencesShape);
|
||||
|
||||
@ -83,7 +83,7 @@ add_library(
|
||||
reducescatterOp.cpp
|
||||
relativeAttentionBiasOp.cpp
|
||||
dsv3RouterGemmOp.cpp
|
||||
renormMoeRoutingOp.cpp
|
||||
customMoeRoutingOp.cpp
|
||||
selectiveScanOp.cpp
|
||||
userbuffersFinalizeOp.cpp
|
||||
userbuffersTensor.cpp
|
||||
@ -119,9 +119,9 @@ endif()
|
||||
|
||||
if(NOT WIN32)
|
||||
set_target_properties(
|
||||
th_common
|
||||
PROPERTIES LINK_FLAGS
|
||||
"-Wl,-rpath='$ORIGIN' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
|
||||
th_common PROPERTIES BUILD_RPATH "$ORIGIN;$ORIGIN/../../nvidia/nccl/lib")
|
||||
set_target_properties(
|
||||
th_common PROPERTIES LINK_FLAGS "${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
|
||||
else()
|
||||
target_link_libraries(th_common PRIVATE context_attention_src)
|
||||
endif()
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/opUtils.h"
|
||||
#include "tensorrt_llm/kernels/renormMoeRoutingKernels.h"
|
||||
#include "tensorrt_llm/kernels/customMoeRoutingKernels.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
|
||||
namespace th = torch;
|
||||
@ -25,7 +25,8 @@ namespace tk = tensorrt_llm::kernels;
|
||||
namespace torch_ext
|
||||
{
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& router_logits, int64_t topk)
|
||||
template <bool DoSoftmaxBeforeTopK>
|
||||
std::tuple<at::Tensor, at::Tensor> custom_moe_routing_op(th::Tensor const& router_logits, int64_t topk)
|
||||
{
|
||||
auto data_type = router_logits.scalar_type();
|
||||
auto input_size = router_logits.sizes();
|
||||
@ -44,20 +45,22 @@ std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& route
|
||||
{
|
||||
case torch::kFloat32:
|
||||
// Handle Float32
|
||||
tk::invokeRenormMoeRouting<float, float, int32_t>(reinterpret_cast<float*>(router_logits.mutable_data_ptr()),
|
||||
tk::invokeRenormMoeRouting<float, float, int32_t, DoSoftmaxBeforeTopK>(
|
||||
reinterpret_cast<float*>(router_logits.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream);
|
||||
break;
|
||||
case torch::kBFloat16:
|
||||
// Handle BFloat16
|
||||
tk::invokeRenormMoeRouting<__nv_bfloat16, float, int32_t>(
|
||||
tk::invokeRenormMoeRouting<__nv_bfloat16, float, int32_t, DoSoftmaxBeforeTopK>(
|
||||
reinterpret_cast<__nv_bfloat16*>(router_logits.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream);
|
||||
break;
|
||||
case torch::kHalf:
|
||||
// Handle Half
|
||||
tk::invokeRenormMoeRouting<half, float, int32_t>(reinterpret_cast<half*>(router_logits.mutable_data_ptr()),
|
||||
tk::invokeRenormMoeRouting<half, float, int32_t, DoSoftmaxBeforeTopK>(
|
||||
reinterpret_cast<half*>(router_logits.mutable_data_ptr()),
|
||||
reinterpret_cast<float*>(topk_values.mutable_data_ptr()),
|
||||
reinterpret_cast<int32_t*>(topk_indices.mutable_data_ptr()), num_tokens, num_experts, topk, stream);
|
||||
break;
|
||||
@ -69,6 +72,15 @@ std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& route
|
||||
return {topk_indices, topk_values};
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> renorm_moe_routing_op(th::Tensor const& router_logits, int64_t topk)
|
||||
{
|
||||
return custom_moe_routing_op<false>(router_logits, topk);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> default_moe_routing_op(th::Tensor const& router_logits, int64_t topk)
|
||||
{
|
||||
return custom_moe_routing_op<true>(router_logits, topk);
|
||||
}
|
||||
} // namespace torch_ext
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
@ -82,3 +94,15 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("renorm_moe_routing_op", &torch_ext::renorm_moe_routing_op);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
{
|
||||
m.def(
|
||||
"default_moe_routing_op(Tensor router_logits, SymInt topk"
|
||||
") -> (Tensor, Tensor)");
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
{
|
||||
m.impl("default_moe_routing_op", &torch_ext::default_moe_routing_op);
|
||||
}
|
||||
@ -48,6 +48,7 @@ namespace common = tensorrt_llm::common;
|
||||
namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE;
|
||||
using ActivationParams = CUTLASS_MOE_GEMM_NAMESPACE::ActivationParams;
|
||||
using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType;
|
||||
using MoeGemmId = CUTLASS_MOE_GEMM_NAMESPACE::MoeGemmId;
|
||||
// Always use public header as it is just utility functions and types
|
||||
using TmaWarpSpecializedGroupedGemmInput = tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput;
|
||||
using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend;
|
||||
@ -215,7 +216,8 @@ public:
|
||||
mKernelRunner->use_fused_finalize_ = mUseFusedFinalize;
|
||||
|
||||
mProfiler = std::make_shared<kernels::GemmProfilerBackend>();
|
||||
mAllProfiles = mKernelRunner->getTactics();
|
||||
mGemm1Profiles = mKernelRunner->getTactics(MoeGemmId::GEMM_1);
|
||||
mGemm2Profiles = mKernelRunner->getTactics(MoeGemmId::GEMM_2);
|
||||
}
|
||||
|
||||
~FusedMoeRunner()
|
||||
@ -585,10 +587,11 @@ public:
|
||||
return std::make_tuple(output, num_active_experts_per_node, experts_to_token_score, active_expert_global_ids);
|
||||
}
|
||||
|
||||
int64_t getTacticNum()
|
||||
int64_t getTacticNum(int64_t const gemm_idx)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
return mAllProfiles.size();
|
||||
TORCH_CHECK(gemm_idx == 1 || gemm_idx == 2, "gemm_idx must be 1 or 2");
|
||||
return (gemm_idx == 1) ? mGemm1Profiles.size() : mGemm2Profiles.size();
|
||||
}
|
||||
|
||||
// TODO Update this to be able to tell if we are profiling swiglu bias
|
||||
@ -624,10 +627,14 @@ public:
|
||||
: group_size_;
|
||||
int const num_experts = static_cast<int>(fc2_expert_weights.sizes()[0] * ep_size);
|
||||
|
||||
auto const gemm_to_profile
|
||||
= (gemm_idx == 1) ? profiler_backend::GemmToProfile::GEMM_1 : profiler_backend::GemmToProfile::GEMM_2;
|
||||
auto const& profiles = (gemm_idx == 1) ? mGemm1Profiles : mGemm2Profiles;
|
||||
|
||||
// Get specific profile configs according to the profile_id.
|
||||
// Fallback tactic is set to be 0
|
||||
// TODO: use the best tactic id found offline for a better default inference perf
|
||||
auto const& profile = profile_id == -1 ? mAllProfiles.front() : mAllProfiles[profile_id];
|
||||
auto const& profile = profile_id == -1 ? profiles.front() : profiles[profile_id];
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
@ -638,8 +645,7 @@ public:
|
||||
if (do_preparation)
|
||||
{
|
||||
// Set profiled gemm idx
|
||||
mProfiler->mGemmToProfile
|
||||
= (gemm_idx == 1) ? profiler_backend::GemmToProfile::GEMM_1 : profiler_backend::GemmToProfile::GEMM_2;
|
||||
mProfiler->mGemmToProfile = gemm_to_profile;
|
||||
|
||||
// mProfiler init
|
||||
auto parallelism_config = kernels::MOEParallelismConfig(static_cast<int>(tp_size),
|
||||
@ -704,7 +710,8 @@ private:
|
||||
bool mUseFusedFinalize = true;
|
||||
|
||||
using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig;
|
||||
std::vector<Profile> mAllProfiles;
|
||||
std::vector<Profile> mGemm1Profiles;
|
||||
std::vector<Profile> mGemm2Profiles;
|
||||
|
||||
void freeProfileWorkspace()
|
||||
{
|
||||
@ -730,15 +737,15 @@ private:
|
||||
return;
|
||||
}
|
||||
|
||||
auto best_gemm1_profile = mAllProfiles.front();
|
||||
auto best_gemm2_profile = mAllProfiles.front();
|
||||
auto best_gemm1_profile = mGemm1Profiles.front();
|
||||
auto best_gemm2_profile = mGemm2Profiles.front();
|
||||
if (profile_ids.has_value())
|
||||
{
|
||||
TORCH_CHECK(profile_ids.value().size() == 2, "Expecting 2 profile ids");
|
||||
best_gemm1_profile
|
||||
= profile_ids.value()[0] == -1 ? best_gemm1_profile : mAllProfiles.at(profile_ids.value()[0]);
|
||||
= profile_ids.value()[0] == -1 ? best_gemm1_profile : mGemm1Profiles.at(profile_ids.value()[0]);
|
||||
best_gemm2_profile
|
||||
= profile_ids.value()[1] == -1 ? best_gemm2_profile : mAllProfiles.at(profile_ids.value()[1]);
|
||||
= profile_ids.value()[1] == -1 ? best_gemm2_profile : mGemm2Profiles.at(profile_ids.value()[1]);
|
||||
}
|
||||
mKernelRunner->setTactic(best_gemm1_profile, best_gemm2_profile);
|
||||
}
|
||||
|
||||
@ -99,7 +99,7 @@ TEST_F(RequestInfoTest, Basic)
|
||||
}
|
||||
auto state = std::make_unique<texec::DataTransceiverState>();
|
||||
state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"});
|
||||
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT});
|
||||
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, nvinfer1::DataType::kFLOAT});
|
||||
RequestInfo info{1, *state};
|
||||
auto info2 = serializeDeserialize(info);
|
||||
EXPECT_EQ(info, info2);
|
||||
@ -133,7 +133,7 @@ TEST_F(CacheConfigTest, EqualTo)
|
||||
constexpr SizeType32 tokensPerBlock{64};
|
||||
constexpr SizeType32 tensorParallelism{8};
|
||||
constexpr SizeType32 pipelineParallelism{2};
|
||||
constexpr SizeType32 contextParallelism{1};
|
||||
constexpr SizeType32 contextParallelism{2};
|
||||
constexpr SizeType32 sizePerHead{hiddenSize / nbHeads};
|
||||
constexpr CacheState::AttentionType attentionType{CacheState::AttentionType::kDEFAULT};
|
||||
constexpr int kvFactor = 2;
|
||||
@ -148,7 +148,7 @@ TEST_F(CacheConfigTest, EqualTo)
|
||||
texec::kv_cache::CacheState state0{
|
||||
cacheStateCfg, worldConfig, modelConfig.getKvDataType(), attentionType, kvFactor};
|
||||
texec::kv_cache::CacheState state1{nbAttentionLayers, nbHeads, sizePerHead, tokensPerBlock, tensorParallelism,
|
||||
pipelineParallelism, dtype, attentionType, kvFactor, false, 0, tensorParallelism};
|
||||
pipelineParallelism, contextParallelism, dtype, attentionType, kvFactor, false, 0, tensorParallelism};
|
||||
EXPECT_EQ(state0, state1);
|
||||
}
|
||||
|
||||
@ -165,7 +165,7 @@ public:
|
||||
ON_CALL(*this, recvRequestInfo)
|
||||
.WillByDefault(Return(RequestInfo{0,
|
||||
texec::DataTransceiverState{
|
||||
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT},
|
||||
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, nvinfer1::DataType::kFLOAT},
|
||||
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
|
||||
ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1));
|
||||
}
|
||||
@ -217,7 +217,8 @@ TEST_F(MockTransceiverTest, MpiResponderBasic)
|
||||
auto sender = std::make_unique<MockDataSender>();
|
||||
EXPECT_CALL(*sender, recvRequestInfo)
|
||||
.WillOnce(Return(RequestInfo{0,
|
||||
texec::DataTransceiverState{texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT},
|
||||
texec::DataTransceiverState{
|
||||
texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, nvinfer1::DataType::kFLOAT},
|
||||
texec::kv_cache::CommState{std::vector<SizeType32>{0}, 0}}}));
|
||||
EXPECT_CALL(*sender, sendSync).WillOnce(Return());
|
||||
EXPECT_CALL(*sender, getCounterpartsCount).WillOnce(Return(1));
|
||||
@ -318,7 +319,7 @@ protected:
|
||||
dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, CacheType::kSELF,
|
||||
std::nullopt, nullptr, true);
|
||||
mCacheState = std::make_unique<texec::kv_cache::CacheState>(
|
||||
numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, dataType);
|
||||
numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, dataType);
|
||||
|
||||
if (tensorrt_llm::common::getEnvUseUCXKvCache())
|
||||
{
|
||||
@ -506,7 +507,7 @@ TEST_F(SymmetricalCacheTest, SimpleTest)
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
|
||||
using AsymmetricTestParam
|
||||
= std::tuple<int, int, int, int, int, int, int, int, nvinfer1::DataType, int, bool, bool, bool, bool>;
|
||||
= std::tuple<int, int, int, int, int, int, int, int, int, int, nvinfer1::DataType, int, bool, bool, bool, bool>;
|
||||
|
||||
class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestParam>
|
||||
{
|
||||
@ -516,8 +517,8 @@ protected:
|
||||
|
||||
void TearDown() override {}
|
||||
|
||||
void setUpCommunicator(int contextTp, int contextPp, int genTp, int genPp, bool isMLA = false,
|
||||
bool contextDP = false, bool generationDP = false)
|
||||
void setUpCommunicator(int contextTp, int contextPp, int contextCp, int genTp, int genPp, int genCp,
|
||||
bool isMLA = false, bool contextDP = false, bool generationDP = false)
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
tensorrt_llm::mpi::initialize(tensorrt_llm::mpi::MpiThreadSupport::THREAD_MULTIPLE);
|
||||
@ -572,11 +573,13 @@ protected:
|
||||
{
|
||||
mTpSize = contextTp;
|
||||
mPpSize = contextPp;
|
||||
mCpSize = contextCp;
|
||||
}
|
||||
if (mIsGeneration)
|
||||
{
|
||||
mTpSize = genTp;
|
||||
mPpSize = genPp;
|
||||
mCpSize = genCp;
|
||||
}
|
||||
|
||||
mTpRank = mRankInInstance % mTpSize;
|
||||
@ -585,6 +588,7 @@ protected:
|
||||
mGenRankSize = genRanks;
|
||||
mContextTpSize = contextTp;
|
||||
mContextPpSize = contextPp;
|
||||
mContextCpSize = contextCp;
|
||||
|
||||
EXPECT_EQ((sessionComm.getRank()), mRankInInstance);
|
||||
EXPECT_EQ(sessionComm.getSize(), mSizeInInstance);
|
||||
@ -696,11 +700,12 @@ protected:
|
||||
texec::kv_cache::CacheState::AttentionType attentionType = isMLA
|
||||
? texec::kv_cache::CacheState::AttentionType::kMLA
|
||||
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
|
||||
mCacheState = std::make_unique<texec::kv_cache::CacheState>(numLayers, numHeadsPerRank, sizePerHead,
|
||||
tokensPerBlock, mTpSize, mPpSize, dataType, attentionType, kvFactor, enableDPAttention, DPrank, DPsize);
|
||||
mCacheState
|
||||
= std::make_unique<texec::kv_cache::CacheState>(numLayers, numHeadsPerRank, sizePerHead, tokensPerBlock,
|
||||
mTpSize, mPpSize, mCpSize, dataType, attentionType, kvFactor, enableDPAttention, DPrank, DPsize);
|
||||
mContextCacheState = std::make_unique<texec::kv_cache::CacheState>(numLayers, numHeadsPerRankForContext,
|
||||
sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, dataType, attentionType, kvFactor, mContextDP,
|
||||
DPrank, mContextTpSize);
|
||||
sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, mContextCpSize, dataType, attentionType,
|
||||
kvFactor, mContextDP, DPrank, mContextTpSize);
|
||||
|
||||
// UVM seems to be incompatible with MPI; this is still under investigation.
|
||||
bool constexpr useUvm = false;
|
||||
@ -859,7 +864,8 @@ protected:
|
||||
texec::kv_cache::CacheState cacheState{mContextCacheState->getModelConfig().mNbKvHeadsPerLayer,
|
||||
mContextCacheState->getModelConfig().mSizePerHead, mContextCacheState->getModelConfig().mTokensPerBlock,
|
||||
mContextCacheState->getParallelConfig().mTensorParallelism,
|
||||
mContextCacheState->getParallelConfig().mPipelineParallelism, mContextCacheState->getDataType(),
|
||||
mContextCacheState->getParallelConfig().mPipelineParallelism,
|
||||
mContextCacheState->getParallelConfig().mContextParallelism, mContextCacheState->getDataType(),
|
||||
mContextCacheState->getAttentionConfig().mAttentionType, mContextCacheState->getAttentionConfig().mKvFactor,
|
||||
mContextCacheState->getParallelConfig().mEnableAttentionDP, contextDpRank,
|
||||
mContextCacheState->getParallelConfig().mTensorParallelism};
|
||||
@ -1094,8 +1100,8 @@ protected:
|
||||
tensorrt_llm::mpi::MpiComm const* mComm;
|
||||
tensorrt_llm::mpi::MpiComm mParticipatingComm{nullptr, false};
|
||||
SizeType32 mWorldSize{0}, mRank{0}, mRankInInstance{0};
|
||||
SizeType32 mSizeInInstance{0}, mTpRank{0}, mPpRank{0}, mTpSize{0}, mPpSize{0}, mContextRankSize{0}, mGenRankSize{0},
|
||||
mContextTpSize{0}, mContextPpSize{0};
|
||||
SizeType32 mSizeInInstance{0}, mTpRank{0}, mPpRank{0}, mTpSize{0}, mPpSize{0}, mCpSize{0}, mContextRankSize{0},
|
||||
mGenRankSize{0}, mContextTpSize{0}, mContextPpSize{0}, mContextCpSize{0};
|
||||
LlmRequest::RequestIdType mRequestId{0};
|
||||
bool mContextDP{false};
|
||||
bool mGenerationDP{false};
|
||||
@ -1129,22 +1135,24 @@ TEST_P(AsymmetricalCacheTest, TestCase)
|
||||
AsymmetricTestParam param = GetParam();
|
||||
int contextTp = std::get<0>(param);
|
||||
int contextPp = std::get<1>(param);
|
||||
int genTp = std::get<2>(param);
|
||||
int genPp = std::get<3>(param);
|
||||
int numLayers = std::get<4>(param);
|
||||
int numHeads = std::get<5>(param);
|
||||
int sizePerHead = std::get<6>(param);
|
||||
int tokensPerBlock = std::get<7>(param);
|
||||
nvinfer1::DataType dataType = std::get<8>(param);
|
||||
int contextCp = std::get<2>(param);
|
||||
int genTp = std::get<3>(param);
|
||||
int genPp = std::get<4>(param);
|
||||
int genCp = std::get<5>(param);
|
||||
int numLayers = std::get<6>(param);
|
||||
int numHeads = std::get<7>(param);
|
||||
int sizePerHead = std::get<8>(param);
|
||||
int tokensPerBlock = std::get<9>(param);
|
||||
nvinfer1::DataType dataType = std::get<10>(param);
|
||||
|
||||
int kvFactor = std::get<9>(param);
|
||||
bool isMLA = std::get<10>(param);
|
||||
bool contextDP = std::get<11>(param);
|
||||
bool generationDP = std::get<12>(param);
|
||||
int kvFactor = std::get<11>(param);
|
||||
bool isMLA = std::get<12>(param);
|
||||
bool contextDP = std::get<13>(param);
|
||||
bool generationDP = std::get<14>(param);
|
||||
|
||||
bool isWindow = std::get<13>(param);
|
||||
bool isWindow = std::get<15>(param);
|
||||
|
||||
setUpCommunicator(contextTp, contextPp, genTp, genPp, isMLA, contextDP, generationDP);
|
||||
setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);
|
||||
|
||||
if (mIsContext || mIsGeneration)
|
||||
{
|
||||
@ -1221,21 +1229,23 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
|
||||
AsymmetricTestParam param = GetParam();
|
||||
int contextTp = std::get<0>(param);
|
||||
int contextPp = std::get<1>(param);
|
||||
int genTp = std::get<2>(param);
|
||||
int genPp = std::get<3>(param);
|
||||
int numLayers = std::get<4>(param);
|
||||
int numHeads = std::get<5>(param);
|
||||
int sizePerHead = std::get<6>(param);
|
||||
int tokensPerBlock = std::get<7>(param);
|
||||
nvinfer1::DataType dataType = std::get<8>(param);
|
||||
int contextCp = std::get<2>(param);
|
||||
int genTp = std::get<3>(param);
|
||||
int genPp = std::get<4>(param);
|
||||
int genCp = std::get<5>(param);
|
||||
int numLayers = std::get<6>(param);
|
||||
int numHeads = std::get<7>(param);
|
||||
int sizePerHead = std::get<8>(param);
|
||||
int tokensPerBlock = std::get<9>(param);
|
||||
nvinfer1::DataType dataType = std::get<10>(param);
|
||||
|
||||
int kvFactor = std::get<9>(param);
|
||||
bool isMLA = std::get<10>(param);
|
||||
bool contextDP = std::get<11>(param);
|
||||
bool generationDP = std::get<12>(param);
|
||||
bool isWindow = std::get<13>(param);
|
||||
int kvFactor = std::get<11>(param);
|
||||
bool isMLA = std::get<12>(param);
|
||||
bool contextDP = std::get<13>(param);
|
||||
bool generationDP = std::get<14>(param);
|
||||
bool isWindow = std::get<15>(param);
|
||||
|
||||
setUpCommunicator(contextTp, contextPp, genTp, genPp, isMLA, contextDP, generationDP);
|
||||
setUpCommunicator(contextTp, contextPp, contextCp, genTp, genPp, genCp, isMLA, contextDP, generationDP);
|
||||
|
||||
if (mIsContext || mIsGeneration)
|
||||
{
|
||||
@ -1324,95 +1334,95 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0, AsymmetricalCacheTest,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
|
||||
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true, false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithWindow, AsymmetricalCacheTest,
|
||||
testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(5),
|
||||
testing::Values(4), testing::Values(4), testing::Values(8),
|
||||
testing::Combine(testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(1),
|
||||
testing::Values(1), testing::Values(5), testing::Values(4), testing::Values(4), testing::Values(8),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(true)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1, AsymmetricalCacheTest,
|
||||
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(8),
|
||||
testing::Values(4), testing::Values(4), testing::Values(8),
|
||||
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4),
|
||||
testing::Values(1), testing::Values(8), testing::Values(4), testing::Values(4), testing::Values(8),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false, true)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest2, AsymmetricalCacheTest,
|
||||
testing::Combine(testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(1, 4),
|
||||
testing::Values(16), testing::Values(16), testing::Values(4), testing::Values(8),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false), testing::Values(false),
|
||||
testing::Values(false), testing::Values(false)));
|
||||
testing::Combine(testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(1),
|
||||
testing::Values(1, 4), testing::Values(1), testing::Values(16), testing::Values(16), testing::Values(4),
|
||||
testing::Values(8), testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false),
|
||||
testing::Values(false), testing::Values(false), testing::Values(false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest0ForMLA, AsymmetricalCacheTest,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
|
||||
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1ForMLA, AsymmetricalCacheTest,
|
||||
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(4),
|
||||
testing::Values(1), testing::Values(4), testing::Values(8),
|
||||
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4),
|
||||
testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(8),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
|
||||
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Values(true), testing::Values(true), testing::Values(true), testing::Values(false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA2, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
|
||||
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Values(true), testing::Values(true), testing::Values(false), testing::Values(false)));
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA3, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
|
||||
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1),
|
||||
testing::Values(true), testing::Values(false), testing::Values(true), testing::Values(false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
|
||||
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(true), testing::Values(true), testing::Values(false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA1, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
|
||||
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(true), testing::Values(false), testing::Values(false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA2, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1, 2),
|
||||
testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false)));
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(4), testing::Values(1),
|
||||
testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(4),
|
||||
testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate1, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(2), testing::Values(2),
|
||||
testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(2),
|
||||
testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false)));
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(4, 2), testing::Values(1),
|
||||
testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4, 2),
|
||||
testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
|
||||
INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP,
|
||||
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1, 2), testing::Values(2),
|
||||
testing::Values(4), testing::Values(1, 2), testing::Values(4), testing::Values(16),
|
||||
testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1, 2),
|
||||
testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1, 2), testing::Values(4),
|
||||
testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
|
||||
testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false)));
|
||||
|
||||
#endif
|
||||
@ -1430,8 +1440,10 @@ TEST(targetTest, CacheStateNODP)
|
||||
|
||||
int contextPP = 2;
|
||||
int contextTP = 4;
|
||||
int contextCP = 1;
|
||||
int genPP = 2;
|
||||
int genTP = 2;
|
||||
int genCP = 1;
|
||||
bool const contextEnableDP = false;
|
||||
bool const genEnableDP = false;
|
||||
|
||||
@ -1441,10 +1453,10 @@ TEST(targetTest, CacheStateNODP)
|
||||
auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA
|
||||
: texec::kv_cache::CacheState::AttentionType::kDEFAULT;
|
||||
auto const contextCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
|
||||
tokensPerBlock, contextTP, contextPP, dataType, attentionType, kvFactor, contextEnableDP, 0, 0};
|
||||
tokensPerBlock, contextTP, contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, 0, 0};
|
||||
|
||||
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
|
||||
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, 0, 0};
|
||||
tokensPerBlock, genTP, genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, 0, 0};
|
||||
|
||||
auto const contextTragetInfo
|
||||
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);
|
||||
@ -1504,8 +1516,10 @@ TEST(targetTest, CacheStateContextDP)
|
||||
|
||||
int contextPP = 1;
|
||||
int contextTP = 4;
|
||||
int contextCP = 1;
|
||||
int genPP = 1;
|
||||
int genTP = 2;
|
||||
int genCP = 1;
|
||||
bool contextEnableDP = true;
|
||||
bool genEnableDP = true;
|
||||
|
||||
@ -1519,10 +1533,11 @@ TEST(targetTest, CacheStateContextDP)
|
||||
|
||||
auto const contextCache
|
||||
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP,
|
||||
contextPP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
|
||||
contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
|
||||
|
||||
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
|
||||
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
|
||||
auto const genCache
|
||||
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, genTP,
|
||||
genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
|
||||
|
||||
auto const contextTragetInfo
|
||||
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank);
|
||||
@ -1625,10 +1640,11 @@ TEST(targetTest, CacheStateContextDP)
|
||||
|
||||
auto const contextCache
|
||||
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP,
|
||||
contextPP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
|
||||
contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP};
|
||||
|
||||
auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead,
|
||||
tokensPerBlock, genTP, genPP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
|
||||
auto const genCache
|
||||
= tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, genTP,
|
||||
genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP};
|
||||
|
||||
auto const contextTragetInfo
|
||||
= tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(contextCache, genCache, generationRank);
|
||||
|
||||
@ -90,7 +90,7 @@ protected:
|
||||
|
||||
size_t maxNumTokens = 1024;
|
||||
mTransBufferManager = std::make_unique<CacheTransBufferManager>(mCacheManager.get(), maxNumTokens);
|
||||
mCacheState = std::make_unique<CacheState>(numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, dataType);
|
||||
mCacheState = std::make_unique<CacheState>(numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, dataType);
|
||||
}
|
||||
|
||||
void TearDown() override
|
||||
|
||||
@ -726,7 +726,7 @@ TEST(SerializeUtilsTest, ContextPhaseParams)
|
||||
{
|
||||
auto state = std::make_unique<texec::DataTransceiverState>();
|
||||
state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"});
|
||||
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, nvinfer1::DataType::kFLOAT});
|
||||
state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, nvinfer1::DataType::kFLOAT});
|
||||
auto stats = texec::ContextPhaseParams({10, 20, 30, 40, 50, 60}, 0, state.release(), VecTokens{10, 20});
|
||||
auto stats2 = serializeDeserialize(stats);
|
||||
EXPECT_EQ(stats, stats2);
|
||||
|
||||
@ -255,7 +255,8 @@ TEST_F(TransferAgentTest, SyncMessage)
|
||||
checked = nixlAgent0->checkRemoteDescs(agent1, regMem3.getDescs());
|
||||
} while (!checked);
|
||||
auto syncMessage = std::string("agent_sync_message");
|
||||
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem3.getDescs(), agent1, syncMessage};
|
||||
nixlAgent0->notifySyncMessage(agent1, syncMessage);
|
||||
TransferRequest writeReq{TransferOp::kWRITE, regMem0.getDescs(), regMem3.getDescs(), agent1};
|
||||
auto status = nixlAgent0->submitTransferRequests(writeReq);
|
||||
|
||||
auto notif = nixlAgent1->getNotifiedSyncMessages();
|
||||
@ -302,7 +303,8 @@ TEST_F(TransferAgentTest, SyncMessage)
|
||||
} while (!checked2);
|
||||
|
||||
std::string syncMessage4 = "four_agent_sync_message";
|
||||
TransferRequest writeReq1{TransferOp::kWRITE, regMem2.getDescs(), regMem1.getDescs(), agent0, syncMessage4};
|
||||
nixlAgent1->notifySyncMessage(agent0, syncMessage4);
|
||||
TransferRequest writeReq1{TransferOp::kWRITE, regMem2.getDescs(), regMem1.getDescs(), agent0};
|
||||
auto status1 = nixlAgent1->submitTransferRequests(writeReq1);
|
||||
auto notif4 = nixlAgent0->getNotifiedSyncMessages();
|
||||
for (std::size_t i = 0; i < MAX_QUERY_TIMES && notif4.size() == 0; i++)
|
||||
|
||||
@ -370,8 +370,8 @@ protected:
|
||||
|
||||
float mSparseMixerEpsilon = 0.2f;
|
||||
|
||||
// Default this to true. This only matters for K>2, and so by doing this we will test the fused and unfused paths
|
||||
bool mUseDeterministicHopperReduce = true;
|
||||
// Default this to false. This only matters for K>2, and so by doing this we will test the fused and unfused paths
|
||||
bool mUseFusedFinalize = false;
|
||||
|
||||
// Disable this for long running tests to speed up runtime
|
||||
bool mIsLongTest = false;
|
||||
@ -456,7 +456,7 @@ protected:
|
||||
{
|
||||
managed_buffers.clear();
|
||||
|
||||
mMoERunner.use_fused_finalize_ = k < 3 || !mUseDeterministicHopperReduce;
|
||||
mMoERunner.use_fused_finalize_ = k < 3 || mUseFusedFinalize;
|
||||
|
||||
mHiddenSize = hidden_size;
|
||||
mInterSize = hidden_size * mInterSizeFraction;
|
||||
@ -1087,9 +1087,9 @@ protected:
|
||||
return std::tuple{(void*) weight_1, (void*) weight_2, bias_1, bias2_ptr, scale_1, scale_2, scale_3};
|
||||
}
|
||||
|
||||
auto getFilteredConfigs(int sm)
|
||||
auto getFilteredConfigs(int sm, MoeGemmId gemm_id)
|
||||
{
|
||||
auto tactics = mMoERunner.getTactics();
|
||||
auto tactics = mMoERunner.getTactics(gemm_id);
|
||||
if (sm == 89 || sm >= 120)
|
||||
{
|
||||
// Filter some unsupported configs for L40S
|
||||
@ -1120,17 +1120,27 @@ protected:
|
||||
auto selectTacticsForArch(int sm)
|
||||
{
|
||||
bool is_tma_warp_specialized = sm >= 90 && !INT_QUANT;
|
||||
auto tactics = getFilteredConfigs(sm);
|
||||
auto it = std::find_if(tactics.begin(), tactics.end(),
|
||||
auto epilogue_fusion_type = (is_tma_warp_specialized && mUseFusedFinalize)
|
||||
? tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE
|
||||
: tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::NONE;
|
||||
auto tactics1 = getFilteredConfigs(sm, MoeGemmId::GEMM_1);
|
||||
auto tactics2 = getFilteredConfigs(sm, MoeGemmId::GEMM_2);
|
||||
auto it1 = std::find_if(tactics1.begin(), tactics1.end(),
|
||||
[is_tma_warp_specialized](auto& c) { return c.is_tma_warp_specialized == is_tma_warp_specialized; });
|
||||
if (it == tactics.end())
|
||||
auto it2 = std::find_if(tactics2.begin(), tactics2.end(),
|
||||
[is_tma_warp_specialized, epilogue_fusion_type](auto& c) {
|
||||
return c.is_tma_warp_specialized == is_tma_warp_specialized
|
||||
&& c.epilogue_fusion_type == epilogue_fusion_type;
|
||||
});
|
||||
if (it1 == tactics1.end() || it2 == tactics2.end())
|
||||
{
|
||||
// Fall back to any tactic
|
||||
std::cout << "WARNING: Could not find config for sm version " << sm << std::endl;
|
||||
return std::pair{tactics[0], tactics[0]};
|
||||
it1 = (it1 == tactics1.end()) ? tactics1.begin() : it1;
|
||||
it2 = (it2 == tactics2.end()) ? tactics2.begin() : it2;
|
||||
}
|
||||
|
||||
return std::pair(*it, *it);
|
||||
return std::pair(*it1, *it2);
|
||||
}
|
||||
|
||||
using ConfigsToTestVec = std::vector<std::pair<tensorrt_llm::cutlass_extensions::CutlassGemmConfig,
|
||||
@ -1164,7 +1174,7 @@ protected:
|
||||
auto stream = mStream->get();
|
||||
auto tactic1 = mInternalSelectedConfig1;
|
||||
auto tactic2 = mInternalSelectedConfig2;
|
||||
if (!tactic1)
|
||||
if (!tactic1 || !tactic2)
|
||||
{
|
||||
int sm = getSMVersion();
|
||||
std::tie(tactic1, tactic2) = selectTacticsForArch(sm);
|
||||
@ -1629,8 +1639,9 @@ void MixtureOfExpertsTest<TypeParam_>::BasicPermuteTest(
|
||||
auto [expected_experts, token_final_scales] = populateRouting(num_experts, num_tokens, k);
|
||||
|
||||
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k);
|
||||
bool should_be_deterministic
|
||||
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
|
||||
bool is_finalize_fusion = gemm2.epilogue_fusion_type
|
||||
== tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
|
||||
bool should_be_deterministic = !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
|
||||
if (should_be_deterministic && !mIsLongTest)
|
||||
{
|
||||
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
|
||||
@ -1749,7 +1760,7 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteSwigluBias)
|
||||
|
||||
TYPED_TEST(MixtureOfExpertsTest, PermuteNonDeterministic)
|
||||
{
|
||||
this->mUseDeterministicHopperReduce = false;
|
||||
this->mUseFusedFinalize = true;
|
||||
// Just test case 3, cases 1&2 always use the fused paths
|
||||
this->BasicPermuteTest(3);
|
||||
}
|
||||
@ -1896,8 +1907,10 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
|
||||
// Only need to init the inputs on the first iteration
|
||||
runMoEPermute(hidden_input, expected_experts, token_final_scales, hidden_size, num_experts, k,
|
||||
MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
|
||||
bool is_finalize_fusion = gemm2.epilogue_fusion_type
|
||||
== tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
|
||||
bool should_be_deterministic
|
||||
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
|
||||
= !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
|
||||
if (should_be_deterministic && !mIsLongTest)
|
||||
{
|
||||
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
|
||||
@ -1912,8 +1925,10 @@ void MixtureOfExpertsTest<TypeParam_>::ParallelismTest(
|
||||
else
|
||||
{
|
||||
runMoEPermute(MOEParallelismConfig{tp_size, i, ep_size, j}, enable_alltoall);
|
||||
bool is_finalize_fusion = gemm2.epilogue_fusion_type
|
||||
== tensorrt_llm::cutlass_extensions::CutlassGemmConfig::EpilogueFusionType::FINALIZE;
|
||||
bool should_be_deterministic
|
||||
= mUseDeterministicHopperReduce || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
|
||||
= !is_finalize_fusion || mK < 3 || getSMVersion() < 90 || getSMVersion() >= 120;
|
||||
if (should_be_deterministic && !mIsLongTest)
|
||||
{
|
||||
auto first_iter = getDataFromDevice(mFinalOutput, mTotalTokens * mHiddenSize);
|
||||
@ -2077,6 +2092,7 @@ PARALLEL_TEST_SUITE(MixedParallel)
|
||||
TYPED_TEST(MixtureOfExpertsTest, ConfigSweep)
|
||||
{
|
||||
this->mIsLongTest = true;
|
||||
this->mUseFusedFinalize = true; // True for all cases because we sweep both
|
||||
auto genConfigName = [](auto conf) -> std::string
|
||||
{
|
||||
using namespace tensorrt_llm::cutlass_extensions;
|
||||
@ -2103,12 +2119,13 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep)
|
||||
auto activation_pool = std::vector{ActivationType::Relu, ActivationType::Swiglu, ActivationType::SwigluBias};
|
||||
if (this->NVFP4)
|
||||
activation_pool = {ActivationType::Relu};
|
||||
auto configs = this->getFilteredConfigs(getSMVersion());
|
||||
auto configs1 = this->getFilteredConfigs(getSMVersion(), MoeGemmId::GEMM_1);
|
||||
auto configs2 = this->getFilteredConfigs(getSMVersion(), MoeGemmId::GEMM_2);
|
||||
for (auto const activation_type : activation_pool)
|
||||
{
|
||||
for (auto conf1 : configs)
|
||||
for (auto conf1 : configs1)
|
||||
{
|
||||
for (auto conf2 : configs)
|
||||
for (auto conf2 : configs2)
|
||||
{
|
||||
auto name1 = genConfigName(conf1);
|
||||
auto name2 = genConfigName(conf2);
|
||||
@ -2120,7 +2137,6 @@ TYPED_TEST(MixtureOfExpertsTest, ConfigSweep)
|
||||
this->mActType = activation_type;
|
||||
for (auto k : {2, 3})
|
||||
{
|
||||
|
||||
this->mOverrideSelectedConfig1 = conf1;
|
||||
this->mOverrideSelectedConfig2 = conf2;
|
||||
this->BasicPermuteTest(k, this->MINIMUM_ALIGNMENT);
|
||||
|
||||
@ -4,8 +4,9 @@ set -ex
|
||||
GITHUB_URL="https://github.com"
|
||||
UCX_INSTALL_PATH="/usr/local/ucx/"
|
||||
CUDA_PATH="/usr/local/cuda"
|
||||
NIXL_VERSION="0.3.1"
|
||||
NIXL_VERSION="0.5.0"
|
||||
NIXL_REPO="https://github.com/ai-dynamo/nixl.git"
|
||||
OLD_LD_LIBRARY_PATH=$LD_LIBRARY_PATH
|
||||
|
||||
ARCH_NAME="x86_64-linux-gnu"
|
||||
GDS_PATH="$CUDA_PATH/targets/x86_64-linux"
|
||||
@ -18,25 +19,26 @@ pip3 install --no-cache-dir meson ninja pybind11
|
||||
git clone --depth 1 -b ${NIXL_VERSION} ${NIXL_REPO}
|
||||
cd nixl
|
||||
|
||||
cuda_path=$(find / -name "libcuda.so.1" 2>/dev/null | head -n1)
|
||||
if [[ -z "$cuda_path" ]]; then
|
||||
echo "libcuda.so.1 not found "
|
||||
CUDA_SO_PATH=$(find "/usr/local" -name "libcuda.so.1" 2>/dev/null | head -n1)
|
||||
|
||||
if [[ -z "$CUDA_SO_PATH" ]]; then
|
||||
echo "libcuda.so.1 not found"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ln -sf $cuda_path $CUDA_PATH/lib64/libcuda.so.1
|
||||
CUDA_SO_PATH=$(dirname $CUDA_SO_PATH)
|
||||
|
||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CUDA_SO_PATH
|
||||
meson setup builddir \
|
||||
-Ducx_path=$UCX_INSTALL_PATH \
|
||||
-Dcudapath_lib="$CUDA_PATH/lib64" \
|
||||
-Dcudapath_inc="$CUDA_PATH/include" \
|
||||
-Dgds_path="$GDS_PATH" \
|
||||
-Dinstall_headers=true \
|
||||
-Dstatic_plugins=UCX
|
||||
-Dinstall_headers=true
|
||||
|
||||
cd builddir && ninja install
|
||||
cd ../..
|
||||
rm -rf nixl* # Remove NIXL source tree to save space
|
||||
rm $CUDA_PATH/lib64/libcuda.so.1
|
||||
export LD_LIBRARY_PATH=$OLD_LD_LIBRARY_PATH
|
||||
|
||||
echo "export LD_LIBRARY_PATH=/opt/nvidia/nvda_nixl/lib/${ARCH_NAME}:/opt/nvidia/nvda_nixl/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
|
||||
|
||||
@ -2,29 +2,28 @@
|
||||
set -ex
|
||||
|
||||
GITHUB_URL="https://github.com"
|
||||
UCX_VERSION="v1.19.0"
|
||||
UCX_VERSION="v1.19.x"
|
||||
UCX_INSTALL_PATH="/usr/local/ucx/"
|
||||
CUDA_PATH="/usr/local/cuda"
|
||||
UCX_REPO="https://github.com/openucx/ucx.git"
|
||||
|
||||
if [ ! -d ${UCX_INSTALL_PATH} ]; then
|
||||
git clone --depth 1 -b ${UCX_VERSION} ${UCX_REPO}
|
||||
cd ucx
|
||||
./autogen.sh
|
||||
./contrib/configure-release \
|
||||
--prefix=${UCX_INSTALL_PATH} \
|
||||
--enable-shared \
|
||||
--disable-static \
|
||||
--disable-doxygen-doc \
|
||||
--enable-optimizations \
|
||||
--enable-cma \
|
||||
--enable-devel-headers \
|
||||
--with-cuda=${CUDA_PATH} \
|
||||
--with-verbs \
|
||||
--with-dm \
|
||||
--enable-mt
|
||||
make install -j$(nproc)
|
||||
cd ..
|
||||
rm -rf ucx # Remove UCX source to save space
|
||||
echo "export LD_LIBRARY_PATH=${UCX_INSTALL_PATH}/lib:\$LD_LIBRARY_PATH" >> "${ENV}"
|
||||
fi
|
||||
rm -rf ${UCX_INSTALL_PATH}
|
||||
git clone --depth 1 -b ${UCX_VERSION} ${UCX_REPO}
|
||||
cd ucx
|
||||
./autogen.sh
|
||||
./contrib/configure-release \
|
||||
--prefix=${UCX_INSTALL_PATH} \
|
||||
--enable-shared \
|
||||
--disable-static \
|
||||
--disable-doxygen-doc \
|
||||
--enable-optimizations \
|
||||
--enable-cma \
|
||||
--enable-devel-headers \
|
||||
--with-cuda=${CUDA_PATH} \
|
||||
--with-verbs \
|
||||
--with-dm \
|
||||
--enable-mt
|
||||
make install -j$(nproc)
|
||||
cd ..
|
||||
rm -rf ucx # Remove UCX source to save space
|
||||
echo "export LD_LIBRARY_PATH=${UCX_INSTALL_PATH}/lib:\$LD_LIBRARY_PATH" >> "${ENV}"
|
||||
|
||||
@ -19,11 +19,11 @@ We have a forthcoming guide for achieving great performance on H100; however, th
|
||||
|
||||
In this section, we introduce several ways to install TensorRT-LLM.
|
||||
|
||||
### NGC Docker Image of dev branch
|
||||
### NGC Docker Image
|
||||
|
||||
Day-0 support for gpt-oss is provided via the NGC container image `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev`. This image was built on top of the pre-day-0 **dev branch**. This container is multi-platform and will run on both x64 and arm64 architectures.
|
||||
Visit the [NGC TensorRT-LLM Release page](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release) to find the most up-to-date NGC container image to use. You can also check the latest [release notes](https://github.com/NVIDIA/TensorRT-LLM/releases) to keep track of the support status of the latest releases.
|
||||
|
||||
Run the following docker command to start the TensorRT-LLM container in interactive mode:
|
||||
Run the following Docker command to start the TensorRT-LLM container in interactive mode (change the image tag to match the latest release):
|
||||
|
||||
```bash
|
||||
docker run --rm --ipc=host -it \
|
||||
@ -33,7 +33,7 @@ docker run --rm --ipc=host -it \
|
||||
-p 8000:8000 \
|
||||
-e TRTLLM_ENABLE_PDL=1 \
|
||||
-v ~/.cache:/root/.cache:rw \
|
||||
nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev \
|
||||
nvcr.io/nvidia/tensorrt-llm/release:1.1.0rc0 \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
@ -53,9 +53,9 @@ Additionally, the container mounts your user `.cache` directory to save the down
|
||||
Support for gpt-oss has been [merged](https://github.com/NVIDIA/TensorRT-LLM/pull/6645) into the **main branch** of TensorRT-LLM. As we continue to optimize gpt-oss performance, you can build TensorRT-LLM from source to get the latest features and support. Please refer to the [doc](https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html) if you want to build from source yourself.
|
||||
|
||||
|
||||
### Regular Release of TensorRT-LLM
|
||||
### TensorRT-LLM Python Wheel Install
|
||||
|
||||
Since gpt-oss has been supported on the main branch, you can get TensorRT-LLM out of the box through its regular release in the future. Please check the latest [release notes](https://github.com/NVIDIA/TensorRT-LLM/releases) to keep track of the support status. The release is provided as [NGC Container Image](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags) or [pip Python wheel](https://pypi.org/project/tensorrt-llm/#history). You can find instructions on pip install [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html).
|
||||
Regular releases of TensorRT-LLM are also provided as [Python wheels](https://pypi.org/project/tensorrt-llm/#history). You can find instructions on the pip install [here](https://nvidia.github.io/TensorRT-LLM/installation/linux.html).
|
||||
|
||||
|
||||
## Performance Benchmarking and Model Serving
|
||||
@ -210,7 +210,10 @@ We can use `trtllm-serve` to serve the model by translating the benchmark comman
|
||||
|
||||
```bash
|
||||
trtllm-serve \
|
||||
gpt-oss-120b \ # Or ${local_model_path}
|
||||
Note: You can also point to a local path containing the model weights instead of the HF repo (e.g., `${local_model_path}`).
|
||||
|
||||
trtllm-serve \
|
||||
openai/gpt-oss-120b \
|
||||
--host 0.0.0.0 \
|
||||
--port 8000 \
|
||||
--backend pytorch \
|
||||
@ -228,7 +231,8 @@ For max-throughput configuration, run:
|
||||
|
||||
```bash
|
||||
trtllm-serve \
|
||||
gpt-oss-120b \ # Or ${local_model_path}
|
||||
trtllm-serve \
|
||||
openai/gpt-oss-120b \
|
||||
--host 0.0.0.0 \
|
||||
--port 8000 \
|
||||
--backend pytorch \
|
||||
@ -262,7 +266,7 @@ curl localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What is NVIDIA's advantage for inference?"
|
||||
"content": "What is NVIDIAs advantage for inference?"
|
||||
}
|
||||
],
|
||||
"max_tokens": 1024,
|
||||
@ -348,12 +352,7 @@ others according to your needs.
|
||||
|
||||
## (H200/H100 Only) Using OpenAI Triton Kernels for MoE
|
||||
|
||||
OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels for Hopper-based GPUs like NVIDIA's H200 for optimal performance. `TRTLLM` MoE backend is not supported on Hopper, and `CUTLASS` backend support is still ongoing. Please enable `TRITON` backend with the steps below if you are running on Hopper GPUs.
|
||||
|
||||
### Installing OpenAI Triton
|
||||
|
||||
The `nvcr.io/nvidia/tensorrt-llm/release:gpt-oss-dev` has prepared Triton already (`echo $TRITON_ROOT` could reveal the path). In other situations, you will need to build and install a specific version of Triton. Please follow the instructions in this [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/gpt_oss#using-openai-triton-kernels-for-moe).
|
||||
|
||||
OpenAI ships a set of Triton kernels optimized for its MoE models. TensorRT-LLM can leverage these kernels on Hopper-based GPUs such as NVIDIA's H200 for optimal performance. The `TRTLLM` MoE backend is not supported on Hopper, and `CUTLASS` backend support is still in progress. Please follow the instructions in this [link](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/models/core/gpt_oss#using-openai-triton-kernels-for-moe) to install and enable the `TRITON` MoE kernels on Hopper GPUs.
|
||||
|
||||
### Selecting Triton as the MoE backend
|
||||
|
||||
|
||||
@ -0,0 +1,328 @@
|
||||
# Quick Start Recipe for GPT-OSS on TensorRT-LLM - Blackwell Hardware
|
||||
|
||||
## Introduction
|
||||
|
||||
This deployment guide provides step-by-step instructions for running the GPT-OSS model using TensorRT-LLM, optimized for NVIDIA GPUs. It covers the complete setup required, from accessing model weights and preparing the software environment to configuring TensorRT-LLM parameters, launching the server, and validating inference output.
|
||||
|
||||
The guide is intended for developers and practitioners seeking high-throughput or low-latency inference using NVIDIA’s accelerated stack—starting with the PyTorch container from NGC, then installing TensorRT-LLM for model serving.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
* GPU: NVIDIA Blackwell Architecture
|
||||
* OS: Linux
|
||||
* Drivers: CUDA Driver 575 or Later
|
||||
* Docker with NVIDIA Container Toolkit installed
|
||||
* Python3 and python3-pip (Optional, for accuracy evaluation only)
|
||||
|
||||
## Models
|
||||
|
||||
* MXFP4 model: [GPT-OSS-120B](https://huggingface.co/openai/gpt-oss-120b)
|
||||
|
||||
|
||||
## MoE Backend Support Matrix
|
||||
|
||||
TensorRT-LLM provides multiple MoE backends. The support matrix for these backends is shown below.
|
||||
|
||||
| Device | Activation Type | MoE Weights Type | MoE Backend | Use Case |
|
||||
|------------|------------------|------------------|-------------|----------------|
|
||||
| B200/GB200 | MXFP8 | MXFP4 | TRTLLM | Low Latency |
|
||||
| B200/GB200 | MXFP8 | MXFP4 | CUTLASS | Max Throughput |
|
||||
|
||||
The default MoE backend is `CUTLASS`, so for any combination that `CUTLASS` does not support, you must set `moe_config.backend` explicitly to run the model, as sketched below.
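For example, a minimal extra LLM API options file that explicitly selects the low-latency backend could look like the following sketch; the full recommended configurations (which use this same file) appear in the Deployment Steps below.

```shell
# Minimal sketch: explicitly select the TRTLLM MoE backend via the extra LLM API options file.
cat << EOF > /tmp/config.yml
moe_config:
  backend: TRTLLM
EOF
```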
|
||||
|
||||
## Deployment Steps
|
||||
|
||||
### Run Docker Container
|
||||
|
||||
Run the docker container using the TensorRT-LLM NVIDIA NGC image.
|
||||
|
||||
```shell
|
||||
docker run --rm -it \
|
||||
--ipc=host \
|
||||
--gpus all \
|
||||
-p 8000:8000 \
|
||||
-v ~/.cache:/root/.cache:rw \
|
||||
--name tensorrt_llm \
|
||||
nvcr.io/nvidia/tensorrt-llm/release:1.0.0rc6 \
|
||||
/bin/bash
|
||||
```
|
||||
|
||||
Note:
|
||||
|
||||
* The command mounts your user `.cache` directory to persist the downloaded model checkpoints, which are saved to `~/.cache/huggingface/hub/` by default. This avoids re-downloading the weights each time you rerun the container. If the `~/.cache` directory doesn't exist, create it with `mkdir ~/.cache`.
|
||||
* You can mount additional directories and paths using the `-v <host_path>:<container_path>` flag if needed, such as mounting the downloaded weight paths.
|
||||
* The command also maps port `8000` from the container to your host so you can access the LLM API endpoint from your host machine.
|
||||
* See <https://catalog.ngc.nvidia.com/orgs/nvidia/teams/tensorrt-llm/containers/release/tags> for all available containers. Containers published weekly from the main branch carry an `rcN` suffix, while monthly releases that pass QA testing have no suffix. Use an `rc` release to get the latest model and feature support.
|
||||
|
||||
If you want to use the latest main branch, you can build TensorRT-LLM from source; see <https://nvidia.github.io/TensorRT-LLM/latest/installation/build-from-source-linux.html> for the steps.
|
||||
|
||||
### Creating the TRT-LLM Server config
|
||||
|
||||
We create a YAML configuration file `/tmp/config.yml` for the TensorRT-LLM Server and populate it with the following recommended performance settings.
|
||||
|
||||
For low-latency with `TRTLLM` MOE backend:
|
||||
|
||||
```shell
|
||||
EXTRA_LLM_API_FILE=/tmp/config.yml
|
||||
|
||||
cat << EOF > ${EXTRA_LLM_API_FILE}
|
||||
enable_attention_dp: false
|
||||
cuda_graph_config:
|
||||
enable_padding: true
|
||||
max_batch_size: 128
|
||||
moe_config:
|
||||
backend: TRTLLM
|
||||
EOF
|
||||
```
|
||||
|
||||
For max-throughput with `CUTLASS` MOE backend:
|
||||
|
||||
```shell
|
||||
EXTRA_LLM_API_FILE=/tmp/config.yml
|
||||
|
||||
cat << EOF > ${EXTRA_LLM_API_FILE}
|
||||
enable_attention_dp: true
|
||||
cuda_graph_config:
|
||||
enable_padding: true
|
||||
max_batch_size: 128
|
||||
moe_config:
|
||||
backend: CUTLASS
|
||||
EOF
|
||||
```
|
||||
|
||||
### Launch the TRT-LLM Server
|
||||
|
||||
Below is an example command to launch the TRT-LLM server with the GPT-OSS model from within the container. The command is specifically configured for the 1024/1024 Input/Output Sequence Length test. The explanation of each flag is shown in the “Configs and Parameters” section.
|
||||
|
||||
```shell
|
||||
trtllm-serve openai/gpt-oss-120b \
|
||||
--host 0.0.0.0 \
|
||||
--port 8000 \
|
||||
--backend pytorch \
|
||||
--max_batch_size 128 \
|
||||
--max_num_tokens 16384 \
|
||||
--max_seq_len 2048 \
|
||||
--kv_cache_free_gpu_memory_fraction 0.9 \
|
||||
--tp_size 8 \
|
||||
--ep_size 8 \
|
||||
--trust_remote_code \
|
||||
--extra_llm_api_options ${EXTRA_LLM_API_FILE}
|
||||
```
|
||||
|
||||
After the server is set up, the client can now send prompt requests to the server and receive results.
|
||||
|
||||
### Configs and Parameters
|
||||
|
||||
These options are used directly on the command line when you start the `trtllm-serve` process.
|
||||
|
||||
#### `--tp_size`
|
||||
|
||||
* **Description:** Sets the **tensor-parallel size**. This should typically match the number of GPUs you intend to use for a single model instance.
|
||||
|
||||
#### `--ep_size`
|
||||
|
||||
* **Description:** Sets the **expert-parallel size** for Mixture-of-Experts (MoE) models. Like `tp_size`, this should generally match the number of GPUs you're using. This setting has no effect on non-MoE models.
|
||||
|
||||
#### `--kv_cache_free_gpu_memory_fraction`
|
||||
|
||||
* **Description:** A value between `0.0` and `1.0` that specifies the fraction of free GPU memory to reserve for the KV cache after the model is loaded. Since memory usage can fluctuate, this buffer helps prevent out-of-memory (OOM) errors.
|
||||
* **Recommendation:** If you experience OOM errors, try reducing this value to `0.7` or lower.
|
||||
|
||||
#### `--backend pytorch`
|
||||
|
||||
* **Description:** Tells TensorRT-LLM to use the **pytorch** backend.
|
||||
|
||||
#### `--max_batch_size`
|
||||
|
||||
* **Description:** The maximum number of user requests that can be grouped into a single batch for processing.
|
||||
|
||||
#### `--max_num_tokens`
|
||||
|
||||
* **Description:** The maximum total number of tokens (across all requests) allowed inside a single scheduled batch.
|
||||
|
||||
#### `--max_seq_len`
|
||||
|
||||
* **Description:** The maximum possible sequence length for a single request, including both input and generated output tokens.
|
||||
|
||||
#### `--trust_remote_code`
|
||||
|
||||
* **Description:** Allows TensorRT-LLM to download models and tokenizers from Hugging Face. This flag is passed directly to the Hugging Face API.
|
||||
|
||||
|
||||
#### Extra LLM API Options (YAML Configuration)
|
||||
|
||||
These options provide finer control over performance and are set within a YAML file passed to the `trtllm-serve` command via the `--extra_llm_api_options` argument.
|
||||
|
||||
#### `cuda_graph_config`
|
||||
|
||||
* **Description**: A section for configuring CUDA graphs to optimize performance.
|
||||
|
||||
* **Options**:
|
||||
|
||||
* `enable_padding`: If `"true"`, input batches are padded to the nearest `cuda_graph_batch_size`. This can significantly improve performance.
|
||||
|
||||
**Default**: `false`
|
||||
|
||||
* `max_batch_size`: Sets the maximum batch size for which a CUDA graph will be created.
|
||||
|
||||
**Default**: `0`
|
||||
|
||||
**Recommendation**: Set this to the same value as the `--max_batch_size` command-line option.
|
||||
|
||||
#### `moe_config`
|
||||
|
||||
* **Description**: Configuration for Mixture-of-Experts (MoE) models.
|
||||
|
||||
* **Options**:
|
||||
|
||||
* `backend`: The backend to use for MoE operations.
|
||||
**Default**: `CUTLASS`
|
||||
|
||||
See the [`TorchLlmArgs` class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options which can be used in the `extra_llm_api_options`.
|
||||
|
||||
## Testing API Endpoint
|
||||
|
||||
### Basic Test
|
||||
|
||||
Start a new terminal on the host to test the TensorRT-LLM server you just launched.
|
||||
|
||||
You can query the health/readiness of the server using:
|
||||
|
||||
```shell
|
||||
curl -s -o /dev/null -w "Status: %{http_code}\n" "http://localhost:8000/health"
|
||||
```
|
||||
|
||||
When the `Status: 200` code is returned, the server is ready for queries. Note that the very first query may take longer due to initialization and compilation.
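If you are scripting the startup, a small polling loop can wait until the endpoint reports ready. Below is a minimal sketch assuming the server from this guide is listening on `localhost:8000`:

```shell
# Poll the /health endpoint until it returns HTTP 200, retrying every 10 seconds
# for up to 10 minutes. Assumes the server is listening on localhost:8000.
for i in $(seq 1 60); do
  status=$(curl -s -o /dev/null -w "%{http_code}" "http://localhost:8000/health")
  if [ "${status}" = "200" ]; then
    echo "Server is ready."
    break
  fi
  echo "Attempt ${i}: server not ready yet (HTTP ${status}), retrying in 10s..."
  sleep 10
done
```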
|
||||
|
||||
After the TRT-LLM server is set up and shows Application startup complete, you can send requests to the server.
|
||||
|
||||
```shell
|
||||
curl http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
|
||||
"model": "openai/gpt-oss-120b",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Where is New York?"
|
||||
}
|
||||
],
|
||||
"max_tokens": 1024,
|
||||
"top_p": 1.0
|
||||
}' -w "\n"
|
||||
```
|
||||
|
||||
Here is an example response, showing that the TRT-LLM server reasons about and answers the question.
|
||||
|
||||
TODO: Use Chat Completions API / Responses API as the example after the PR is merged.
|
||||
|
||||
```json
|
||||
{"id":"chatcmpl-c5bf51b5cab94e10ba5da5266d12ee59","object":"chat.completion","created":1755815898,"model":"openai/gpt-oss-120b","choices":[{"index":0,"message":{"role":"assistant","content":"analysisThe user asks: \"Where is New York?\" Likely they want location info. Provide answer: New York State in northeastern US, New York City on the east coast, coordinates, etc. Provide context.assistantfinal**New York** can refer to two related places in the United States:\n\n| What it is | Where it is | Approx. coordinates | How to picture it |\n|------------|------------|--------------------|-------------------|\n| **New York State** | The northeastern corner of the United States, bordered by **Vermont, Massachusetts, Connecticut, New Jersey, Pennsylvania, and the Canadian provinces of Ontario and Quebec**. | 42.7° N, 75.5° W (roughly the state’s geographic centre) | A roughly rectangular state that stretches from the Atlantic Ocean in the southeast to the Adirondack Mountains and the Great Lakes region in the north. |\n| **New York City (NYC)** | The largest city in the state, located on the **southern tip of the state** where the **Hudson River meets the Atlantic Ocean**. It occupies five boroughs: Manhattan, Brooklyn, Queens, The Bronx, and Staten Island. | 40.7128° N, 74.0060° W | A dense, world‑famous metropolis that sits on a series of islands (Manhattan, Staten Island, parts of the Bronx) and the mainland (Brooklyn and Queens). |\n\n### Quick geographic context\n- **On a map of the United States:** New York State is in the **Northeast** region, just east of the Great Lakes and north of Pennsylvania. \n- **From Washington, D.C.:** Travel roughly **225 mi (360 km) northeast**. \n- **From Boston, MA:** Travel about **215 mi (350 km) southwest**. \n- **From Toronto, Canada:** Travel about **500 mi (800 km) southeast**.\n\n### Travel tips\n- **By air:** Major airports include **John F. Kennedy International (JFK)**, **LaGuardia (LGA)**, and **Newark Liberty International (EWR)** (the latter is actually in New Jersey but serves the NYC metro area). \n- **By train:** Amtrak’s **Northeast Corridor** runs from **Boston → New York City → Washington, D.C.** \n- **By car:** Interstates **I‑87** (north‑south) and **I‑90** (east‑west) are the primary highways crossing the state.\n\n### Fun fact\n- The name “**New York**” was given by the English in 1664, honoring the Duke of York (later King James II). The city’s original Dutch name was **“New Amsterdam.”**\n\nIf you need more specific directions (e.g., how to get to a particular neighborhood, landmark, or the state capital **Albany**), just let me know!","reasoning_content":null,"tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null,"mm_embedding_handle":null,"disaggregated_params":null,"avg_decoded_tokens_per_iter":1.0}],"usage":{"prompt_tokens":72,"total_tokens":705,"completion_tokens":633},"prompt_token_ids":null}
|
||||
```
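To inspect only the generated text rather than the full JSON payload, you can pipe the response through `jq` (assuming `jq` is available in your environment); the field path below matches the response structure shown above:

```shell
# Extract just the assistant's message text from the chat completions response.
curl -s http://localhost:8000/v1/chat/completions -H "Content-Type: application/json" -d '{
  "model": "openai/gpt-oss-120b",
  "messages": [{"role": "user", "content": "Where is New York?"}],
  "max_tokens": 256
}' | jq -r '.choices[0].message.content'
```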
|
||||
|
||||
### Troubleshooting Tips
|
||||
|
||||
* If you encounter CUDA out-of-memory errors, try reducing `max_batch_size` or `max_seq_len`.
|
||||
* Ensure your model checkpoints are compatible with the expected format.
|
||||
* For performance issues, check GPU utilization with `nvidia-smi` while the server is running (see the example after this list).
|
||||
* If the container fails to start, verify that the NVIDIA Container Toolkit is properly installed.
|
||||
* For connection issues, make sure the server port (`8000` in this guide) is not being used by another application.
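For the GPU utilization check mentioned above, one simple option is to poll `nvidia-smi` while requests are in flight; any GPU monitoring tool works equally well.

```shell
# Sample per-GPU utilization and memory usage once per second while the server handles requests.
nvidia-smi --query-gpu=index,utilization.gpu,memory.used,memory.total --format=csv -l 1
```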
|
||||
|
||||
### Running Evaluations to Verify Accuracy (Optional)
|
||||
|
||||
We use OpenAI's official evaluation tool to test the model's accuracy. For more information see [gpt-oss-eval](https://github.com/openai/gpt-oss/tree/main/gpt_oss/evals).
|
||||
|
||||
TODO(@Binghan Chen): Add instructions for running gpt-oss-eval.
|
||||
|
||||
## Benchmarking Performance
|
||||
|
||||
To benchmark the performance of your TensorRT-LLM server, you can leverage the built-in `benchmark_serving.py` script. To do this, first create a wrapper `bench.sh` script.
|
||||
|
||||
```shell
|
||||
cat <<'EOF' > bench.sh
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
concurrency_list="32 64 128 256 512 1024 2048 4096"
|
||||
multi_round=5
|
||||
isl=1024
|
||||
osl=1024
|
||||
result_dir=/tmp/gpt_oss_output
|
||||
|
||||
for concurrency in ${concurrency_list}; do
|
||||
num_prompts=$((concurrency * multi_round))
|
||||
python -m tensorrt_llm.serve.scripts.benchmark_serving \
|
||||
--model openai/gpt-oss-120b \
|
||||
--backend openai \
|
||||
--dataset-name "random" \
|
||||
--random-input-len ${isl} \
|
||||
--random-output-len ${osl} \
|
||||
--random-prefix-len 0 \
|
||||
--random-ids \
|
||||
--num-prompts ${num_prompts} \
|
||||
--max-concurrency ${concurrency} \
|
||||
--ignore-eos \
|
||||
--tokenize-on-client \
|
||||
--percentile-metrics "ttft,tpot,itl,e2el"
|
||||
done
|
||||
EOF
|
||||
chmod +x bench.sh
|
||||
```
|
||||
|
||||
If you want to save the results to a file, add the following options to the `benchmark_serving` invocation inside `bench.sh`.
|
||||
|
||||
```shell
|
||||
--save-result \
|
||||
--result-dir "${result_dir}" \
|
||||
--result-filename "concurrency_${concurrency}.json"
|
||||
```
|
||||
|
||||
For more benchmarking options see <https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/serve/scripts/benchmark_serving.py>.
|
||||
|
||||
Run `bench.sh` to begin a serving benchmark. This can take a long time if you run all of the concurrency values listed in the `bench.sh` script above.
|
||||
|
||||
```shell
|
||||
./bench.sh
|
||||
```
|
||||
|
||||
Sample TensorRT-LLM serving benchmark output. Your results may vary due to ongoing software optimizations.
|
||||
|
||||
```
============ Serving Benchmark Result ============
Successful requests:                     16
Benchmark duration (s):                  17.66
Total input tokens:                      16384
Total generated tokens:                  16384
Request throughput (req/s):              [result]
Output token throughput (tok/s):         [result]
Total Token throughput (tok/s):          [result]
User throughput (tok/s):                 [result]
---------------Time to First Token----------------
Mean TTFT (ms):                          [result]
Median TTFT (ms):                        [result]
P99 TTFT (ms):                           [result]
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          [result]
Median TPOT (ms):                        [result]
P99 TPOT (ms):                           [result]
---------------Inter-token Latency----------------
Mean ITL (ms):                           [result]
Median ITL (ms):                         [result]
P99 ITL (ms):                            [result]
----------------End-to-end Latency----------------
Mean E2EL (ms):                          [result]
Median E2EL (ms):                        [result]
P99 E2EL (ms):                           [result]
==================================================
```

### Key Metrics

* Median Time to First Token (TTFT)
  * The typical time elapsed from when a request is sent until the first output token is generated.
* Median Time Per Output Token (TPOT)
  * The typical time required to generate each token *after* the first one.
* Median Inter-Token Latency (ITL)
  * The typical time delay between the completion of one token and the completion of the next.
* Median End-to-End Latency (E2EL)
  * The typical total time from when a request is submitted until the final token of the response is received.
* Total Token Throughput
  * The combined rate at which the system processes both input (prompt) tokens and output (generated) tokens; a worked example follows this list.
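
For example, using the numbers in the sample output above (16,384 input tokens plus 16,384 generated tokens over a 17.66-second run), the total token throughput works out to (16384 + 16384) / 17.66 ≈ 1,855 tok/s.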

@ -38,6 +38,7 @@ Welcome to TensorRT-LLM's Documentation!
   deployment-guide/quick-start-recipe-for-llama4-scout-on-trtllm.md
   deployment-guide/quick-start-recipe-for-deepseek-r1-on-trtllm.md
   deployment-guide/quick-start-recipe-for-llama3.3-70b-on-trtllm.md
   deployment-guide/quick-start-recipe-for-gpt-oss-on-trtllm.md


.. toctree::

@ -0,0 +1,77 @@

# Serving with trtllm-serve

AutoDeploy integrates with the OpenAI-compatible `trtllm-serve` CLI so you can expose AutoDeploy-optimized models over HTTP without writing server code. This page shows how to launch the server with the AutoDeploy backend, configure it via YAML, and validate it with a simple request.

## Quick start

Launch `trtllm-serve` with the AutoDeploy backend by setting `--backend _autodeploy`:

```bash
trtllm-serve \
  meta-llama/Llama-3.1-8B-Instruct \
  --backend _autodeploy
```

- `model`: Hugging Face model name or local path
- `--backend _autodeploy`: uses the AutoDeploy runtime

Once the server is ready, test it with an OpenAI-compatible request:

```bash
curl -s http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [{"role": "system", "content": "You are a helpful assistant."},
                 {"role": "user", "content": "Where is New York? Tell me in a single sentence."}],
    "max_tokens": 32
  }'
```
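
The response is a standard OpenAI chat-completions JSON object. As a minimal sketch, assuming `jq` is available on the client, you can extract just the assistant message:

```bash
curl -s http://localhost:8000/v1/chat/completions \
  -H 'Content-Type: application/json' \
  -d '{"model": "meta-llama/Llama-3.1-8B-Instruct",
       "messages": [{"role": "user", "content": "Where is New York? Tell me in a single sentence."}],
       "max_tokens": 32}' \
  | jq -r '.choices[0].message.content'
```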

## Configuration via YAML

Use `--extra_llm_api_options` to supply a YAML file that augments or overrides server/runtime settings.

```bash
trtllm-serve \
  meta-llama/Llama-3.1-8B \
  --backend _autodeploy \
  --extra_llm_api_options autodeploy_config.yaml
```

Example `autodeploy_config.yaml`:

```yaml
# Compilation backend for AutoDeploy
compile_backend: torch-opt  # options: torch-simple, torch-compile, torch-cudagraph, torch-opt

# Runtime engine
runtime: trtllm  # options: trtllm, demollm

# Model loading
skip_loading_weights: false  # set true for architecture-only perf runs

# KV cache memory
free_mem_ratio: 0.8  # fraction of free GPU mem for KV cache

# CUDA graph optimization
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]

# Attention backend
attn_backend: flashinfer  # recommended for best performance
```

## Limitations and tips

- KV cache block reuse is disabled automatically for the AutoDeploy backend.
- The AutoDeploy backend does not yet support disaggregated serving (work in progress).
- For best performance (a combined configuration sketch follows this list):
  - Prefer `compile_backend: torch-opt`
  - Use `attn_backend: flashinfer`
  - Set realistic `cuda_graph_batch_sizes` that match expected traffic
  - Tune `free_mem_ratio` to 0.8–0.9
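
A minimal sketch that combines these recommendations into one config file; the batch sizes and the `free_mem_ratio` value below are placeholders to adapt to your traffic profile and GPU memory budget:

```bash
cat > autodeploy_perf.yaml <<'EOF'
compile_backend: torch-opt
attn_backend: flashinfer
cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64]  # match your expected concurrency levels
free_mem_ratio: 0.85                              # within the recommended 0.8-0.9 range
EOF

trtllm-serve meta-llama/Llama-3.1-8B-Instruct \
  --backend _autodeploy \
  --extra_llm_api_options autodeploy_perf.yaml
```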

## See also

- [AutoDeploy overview](../auto-deploy.md)
- [Benchmarking with trtllm-bench](./benchmarking_with_trtllm_bench.md)

@ -59,6 +59,7 @@ The exported graph then undergoes a series of automated transformations, includi
- [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md)
- [Expert Configurations](./advanced/expert_configurations.md)
- [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md)
- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md)

## Roadmap

@ -1,3 +1,3 @@
tensorrt_llm==1.1.0rc1
tensorrt_llm==1.1.0rc2
evaluate~=0.4.1
rouge_score~=0.1.2
||||
@ -7,8 +7,8 @@ from difflib import SequenceMatcher
|
||||
import torch
|
||||
|
||||
from tensorrt_llm import LLM, SamplingParams
|
||||
from tensorrt_llm.llmapi.llm_args import KvCacheConfig
|
||||
from tensorrt_llm.mapping import CpType
|
||||
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
|
||||
|
||||
|
||||
def dump_jsonl(data, fname):
|
||||
@ -54,11 +54,8 @@ def similarity_score(a, b):
|
||||
return SequenceMatcher(None, a, b).ratio()
|
||||
|
||||
|
||||
# Generate the outputs using either TRT or PyTorch (based on the use_pytorch argument). It’s the same function for both workflows.
|
||||
def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False):
|
||||
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
|
||||
kv_cache_quant_algo=QuantAlgo.FP8 if fp8_kv_cache
|
||||
else None) if fp8 else QuantConfig()
|
||||
kv_cache_config = KvCacheConfig(dtype="fp8" if fp8_kv_cache else "auto")
|
||||
cp_config = {
|
||||
"cp_type": CpType.STAR,
|
||||
"cp_anchor_size": args.sa_anchor_size,
|
||||
@ -70,7 +67,7 @@ def generate_llm_outputs(args, data, fp8=False, fp8_kv_cache=False):
|
||||
max_input_len=args.max_input_len,
|
||||
max_seq_len=args.max_seq_len,
|
||||
max_num_tokens=args.max_num_tokens,
|
||||
quant_config=quant_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
tensor_parallel_size=1,
|
||||
context_parallel_size=args.num_procs,
|
||||
cp_config=cp_config,
|
||||
|
||||
@ -57,10 +57,10 @@ def CONFIG_LINUX_AARCH64_CU12 = "linux_aarch64_CU12"
|
||||
def CONFIG_LINUX_AARCH64_LLVM = "linux_aarch64_LLVM"
|
||||
|
||||
@Field
|
||||
def CONFIG_LINUX_X86_64_NANOBIND = "linux_x86_64_Nanobind"
|
||||
def CONFIG_LINUX_X86_64_PYBIND = "linux_x86_64_Pybind"
|
||||
|
||||
@Field
|
||||
def CONFIG_LINUX_AARCH64_NANOBIND = "linux_aarch64_Nanobind"
|
||||
def CONFIG_LINUX_AARCH64_PYBIND = "linux_aarch64_Pybind"
|
||||
|
||||
@Field
|
||||
def BUILD_CONFIGS = [
|
||||
@ -76,9 +76,9 @@ def BUILD_CONFIGS = [
|
||||
(TARNAME) : "TensorRT-LLM-CU12.tar.gz",
|
||||
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real",
|
||||
],
|
||||
(CONFIG_LINUX_X86_64_NANOBIND) : [
|
||||
(WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
|
||||
(TARNAME) : "nanobind-TensorRT-LLM.tar.gz",
|
||||
(CONFIG_LINUX_X86_64_PYBIND) : [
|
||||
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars ENABLE_MULTI_DEVICE=1 --extra-cmake-vars WARNING_IS_ERROR=ON --extra-cmake-vars NIXL_ROOT=/opt/nvidia/nvda_nixl --micro_benchmarks",
|
||||
(TARNAME) : "pybind-TensorRT-LLM.tar.gz",
|
||||
(WHEEL_ARCHS): "80-real;86-real;89-real;90-real;100-real;120-real",
|
||||
],
|
||||
(CONFIG_LINUX_X86_64_SINGLE_DEVICE) : [
|
||||
@ -101,9 +101,9 @@ def BUILD_CONFIGS = [
|
||||
(TARNAME) : "TensorRT-LLM-GH200-CU12.tar.gz",
|
||||
(WHEEL_ARCHS): "90-real;100-real;120-real",
|
||||
],
|
||||
(CONFIG_LINUX_AARCH64_NANOBIND): [
|
||||
(WHEEL_EXTRA_ARGS) : "--binding_type nanobind --extra-cmake-vars WARNING_IS_ERROR=ON",
|
||||
(TARNAME) : "nanobind-TensorRT-LLM-GH200.tar.gz",
|
||||
(CONFIG_LINUX_AARCH64_PYBIND): [
|
||||
(WHEEL_EXTRA_ARGS) : "--binding_type pybind --extra-cmake-vars WARNING_IS_ERROR=ON",
|
||||
(TARNAME) : "pybind-TensorRT-LLM-GH200.tar.gz",
|
||||
(WHEEL_ARCHS): "90-real;100-real;120-real",
|
||||
],
|
||||
(CONFIG_LINUX_AARCH64_LLVM) : [
|
||||
@ -568,8 +568,8 @@ def launchStages(pipeline, cpu_arch, enableFailFast, globalVars)
|
||||
pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_CU12 : CONFIG_LINUX_X86_64_VANILLA_CU12),
|
||||
"Build TRT-LLM LLVM": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
|
||||
pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_LLVM : CONFIG_LINUX_X86_64_LLVM),
|
||||
"Build TRT-LLM Nanobind": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
|
||||
pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_NANOBIND : CONFIG_LINUX_X86_64_NANOBIND),
|
||||
"Build TRT-LLM Pybind": [LLM_DOCKER_IMAGE] + prepareLLMBuild(
|
||||
pipeline, cpu_arch == AARCH64_TRIPLE ? CONFIG_LINUX_AARCH64_PYBIND : CONFIG_LINUX_X86_64_PYBIND),
|
||||
]
|
||||
|
||||
if (cpu_arch == X86_64_TRIPLE) {
|
||||
|
||||
@ -74,7 +74,7 @@ def LINUX_AARCH64_CONFIG = "linux_aarch64"
|
||||
def LINUX_AARCH64_CONFIG_CU12 = "linux_aarch64_CU12"
|
||||
|
||||
@Field
|
||||
def NANOBIND_CONFIG = "Nanobind"
|
||||
def PYBIND_CONFIG = "Pybind"
|
||||
|
||||
@Field
|
||||
def BUILD_CONFIGS = [
|
||||
@ -85,7 +85,7 @@ def BUILD_CONFIGS = [
|
||||
(LLVM_CONFIG) : [(TARNAME) : "llvm-TensorRT-LLM.tar.gz"],
|
||||
(LINUX_AARCH64_CONFIG) : [(TARNAME) : "TensorRT-LLM-GH200.tar.gz"],
|
||||
(LINUX_AARCH64_CONFIG_CU12) : [(TARNAME) : "TensorRT-LLM-GH200-CU12.tar.gz"],
|
||||
(NANOBIND_CONFIG) : [(TARNAME) : "nanobind-TensorRT-LLM.tar.gz"],
|
||||
(PYBIND_CONFIG) : [(TARNAME) : "pybind-TensorRT-LLM.tar.gz"],
|
||||
]
|
||||
|
||||
// TODO: Move common variables to an unified location
|
||||
@ -657,8 +657,7 @@ def createKubernetesPodConfig(image, type, arch = "amd64", gpuCount = 1, perfMod
|
||||
def driverVersion = Constants.DEFAULT_NVIDIA_DRIVER_VERSION
|
||||
def cpuCount = "${TESTER_CORES}"
|
||||
|
||||
// Multi-GPU only supports DGX-H100 and DGX-H200 due to the hardware stability.
|
||||
if ((type.contains("dgx-h100") || type.contains("dgx-h200")) && hasMultipleGPUs)
|
||||
if (hasMultipleGPUs)
|
||||
{
|
||||
// Not a hard requirement, but based on empirical values.
|
||||
memorySize = "${gpuCount * 150}" + "Gi"
|
||||
@ -672,7 +671,7 @@ def createKubernetesPodConfig(image, type, arch = "amd64", gpuCount = 1, perfMod
|
||||
targetCould = "kubernetes"
|
||||
|
||||
// The following GPU types doesn't support dynamic driver flashing.
|
||||
if (type.contains("dgx-h100") || type.contains("dgx-h200") || type in ["b100-ts2", "gh200", "rtx-5080", "rtx-5090"]) {
|
||||
if (type.contains("dgx-h100") || type.contains("dgx-h200") || type.contains("rtx-pro-6000") || type in ["b100-ts2", "gh200", "rtx-5080", "rtx-5090"]) {
|
||||
selectors = """
|
||||
kubernetes.io/arch: ${arch}
|
||||
kubernetes.io/os: linux
|
||||
@ -1281,6 +1280,7 @@ def runLLMTestlistOnPlatformImpl(pipeline, platform, testList, config=VANILLA_CO
|
||||
echoNodeAndGpuInfo(pipeline, stageName)
|
||||
sh "cat ${MODEL_CACHE_DIR}/README"
|
||||
sh "nvidia-smi -q"
|
||||
sh "nvidia-smi topo -m"
|
||||
sh "df -h"
|
||||
|
||||
// setup HF_HOME to cache model and datasets
|
||||
@ -1789,7 +1789,7 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
"A10-TensorRT-4": ["a10", "l0_a10", 4, 6],
|
||||
"A10-TensorRT-5": ["a10", "l0_a10", 5, 6],
|
||||
"A10-TensorRT-6": ["a10", "l0_a10", 6, 6],
|
||||
"A10-Nanobind": ["a10", "l0_a10_nanobind", 1, 1],
|
||||
"A10-Pybind": ["a10", "l0_a10_pybind", 1, 1],
|
||||
"A30-Triton-1": ["a30", "l0_a30", 1, 1],
|
||||
"A30-PyTorch-1": ["a30", "l0_a30", 1, 2],
|
||||
"A30-PyTorch-2": ["a30", "l0_a30", 2, 2],
|
||||
@ -1809,8 +1809,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
"B200_PCIe-PyTorch-1": ["b100-ts2", "l0_b200", 1, 3],
|
||||
"B200_PCIe-PyTorch-2": ["b100-ts2", "l0_b200", 2, 3],
|
||||
"B200_PCIe-PyTorch-3": ["b100-ts2", "l0_b200", 3, 3],
|
||||
"B200_PCIe-TensorRT-1": ["b100-ts2", "l0_b200", 1, 2],
|
||||
"B200_PCIe-TensorRT-2": ["b100-ts2", "l0_b200", 2, 2],
|
||||
"RTX5090-PyTorch-1": ["rtx-5090", "l0_gb202", 1, 1],
|
||||
"RTX5080-TensorRT-1": ["rtx-5080", "l0_gb203", 1, 2],
|
||||
"RTX5080-TensorRT-2": ["rtx-5080", "l0_gb203", 2, 2],
|
||||
@ -1850,6 +1848,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
"H100_PCIe-TensorRT-Post-Merge-5": ["h100-cr", "l0_h100", 5, 5],
|
||||
"B200_PCIe-Triton-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 1],
|
||||
"B200_PCIe-PyTorch-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 1],
|
||||
"B200_PCIe-TensorRT-Post-Merge-1": ["b100-ts2", "l0_b200", 1, 2],
|
||||
"B200_PCIe-TensorRT-Post-Merge-2": ["b100-ts2", "l0_b200", 2, 2],
|
||||
"H100_PCIe-TensorRT-Perf-1": ["h100-cr", "l0_perf", 1, 1],
|
||||
"H100_PCIe-PyTorch-Perf-1": ["h100-cr", "l0_perf", 1, 1],
|
||||
"DGX_H200-8_GPUs-PyTorch-Post-Merge-1": ["dgx-h200-x8", "l0_dgx_h200", 1, 1, 8],
|
||||
@ -1857,6 +1857,9 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
"DGX_H200-4_GPUs-TensorRT-Post-Merge-1": ["dgx-h200-x4", "l0_dgx_h200", 1, 3, 4],
|
||||
"DGX_H200-4_GPUs-TensorRT-Post-Merge-2": ["dgx-h200-x4", "l0_dgx_h200", 2, 3, 4],
|
||||
"DGX_H200-4_GPUs-TensorRT-Post-Merge-3": ["dgx-h200-x4", "l0_dgx_h200", 3, 3, 4],
|
||||
"RTXPro6000-Pytorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
|
||||
"RTXPro6000-4_GPUs-Pytorch-Post-Merge-1": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 1, 2, 4],
|
||||
"RTXPro6000-4_GPUs-Pytorch-Post-Merge-2": ["rtx-pro-6000-x4", "l0_rtx_pro_6000", 2, 2, 4],
|
||||
]
|
||||
|
||||
parallelJobs = x86TestConfigs.collectEntries{key, values -> [key, [createKubernetesPodConfig(key.contains("-CU12-") ? LLM_DOCKER_IMAGE_12_9 : LLM_DOCKER_IMAGE, values[0], "amd64", values[4] ?: 1, key.contains("Perf")), {
|
||||
@ -1867,8 +1870,8 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
if (key.contains("llvm")) {
|
||||
config = LLVM_CONFIG
|
||||
}
|
||||
if (key.contains("Nanobind")) {
|
||||
config = NANOBIND_CONFIG
|
||||
if (key.contains("Pybind")) {
|
||||
config = PYBIND_CONFIG
|
||||
}
|
||||
if (key.contains("-CU12-")) {
|
||||
config = VANILLA_CONFIG_CU12
|
||||
@ -1878,7 +1881,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
fullSet = parallelJobs.keySet()
|
||||
|
||||
x86SlurmTestConfigs = [
|
||||
"RTXPro6000-PyTorch-Post-Merge-1": ["rtx-pro-6000", "l0_rtx_pro_6000", 1, 1],
|
||||
"DGX_B200-4_GPUs-PyTorch-Post-Merge-1": ["b200-x4", "l0_dgx_b200", 1, 1, 4],
|
||||
]
|
||||
fullSet += x86SlurmTestConfigs.keySet()
|
||||
@ -2099,11 +2101,6 @@ def launchTestJobs(pipeline, testFilter, dockerNode=null)
|
||||
checkPipStage = true
|
||||
}
|
||||
|
||||
if (cpu_arch == AARCH64_TRIPLE && values[5] != DLFW_IMAGE) {
|
||||
checkPipStage = false
|
||||
echo "Skip pip install sanity check due to https://nvbugs/5453827"
|
||||
}
|
||||
|
||||
if (checkPipStage) {
|
||||
stage("Run LLMAPI tests") {
|
||||
pipInstallSanitySpec = createKubernetesPodConfig(values[5], gpu_type, k8s_arch)
|
||||
@ -2484,7 +2481,7 @@ pipeline {
|
||||
|
||||
def testPhase2StageName = env.testPhase2StageName
|
||||
if (testPhase2StageName) {
|
||||
def dgxSigns = ["DGX_H100", "DGX_H200", "GB200", "DGX_B200"]
|
||||
def dgxSigns = ["DGX_H100", "DGX_H200", "GB200", "DGX_B200", "RTXPro6000-4_GPUs"]
|
||||
singleGpuJobs = parallelJobs.findAll{!dgxSigns.any{sign -> it.key.contains(sign)}}
|
||||
dgxJobs = parallelJobs.findAll{dgxSigns.any{sign -> it.key.contains(sign)}}
|
||||
}
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu128
|
||||
-c constraints.txt
|
||||
accelerate>=0.25.0
|
||||
accelerate>=1.7.0
|
||||
build
|
||||
colored
|
||||
cuda-python>=12
|
||||
|
||||
@ -435,7 +435,7 @@ def main(*,
|
||||
install: bool = False,
|
||||
skip_building_wheel: bool = False,
|
||||
linking_install_binary: bool = False,
|
||||
binding_type: str = "pybind",
|
||||
binding_type: str = "nanobind",
|
||||
benchmarks: bool = False,
|
||||
micro_benchmarks: bool = False,
|
||||
nvtx: bool = False,
|
||||
@ -984,8 +984,8 @@ def add_arguments(parser: ArgumentParser):
|
||||
)
|
||||
parser.add_argument("--binding_type",
|
||||
choices=["pybind", "nanobind"],
|
||||
default="pybind",
|
||||
help="Which binding type to build: pybind or nanobind")
|
||||
default="nanobind",
|
||||
help="Which binding library to use: pybind or nanobind")
|
||||
parser.add_argument("--benchmarks",
|
||||
action="store_true",
|
||||
help="Build the benchmarks for the C++ runtime")
|
||||
|
||||
@ -19,6 +19,11 @@ transforms:
|
||||
stage: post_export
|
||||
cleanup_input_constraints:
|
||||
stage: post_export
|
||||
############################################################################################
|
||||
# RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
|
||||
############################################################################################
|
||||
match_moe_pattern:
|
||||
stage: pattern_matcher
|
||||
match_repeat_kv:
|
||||
stage: pattern_matcher
|
||||
match_eager_attention:
|
||||
@ -27,12 +32,13 @@ transforms:
|
||||
stage: pattern_matcher
|
||||
match_attention_layout:
|
||||
stage: pattern_matcher
|
||||
match_moe_pattern:
|
||||
stage: pattern_matcher
|
||||
match_rope_pattern:
|
||||
stage: pattern_matcher
|
||||
match_rope_layout:
|
||||
stage: pattern_matcher
|
||||
############################################################################################
|
||||
# RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION
|
||||
############################################################################################
|
||||
eliminate_redundant_transposes:
|
||||
stage: pattern_matcher
|
||||
# TODO (lucaslie): let's move this to perf optimization once TP sharding is improved
|
||||
@ -57,5 +63,44 @@ transforms:
|
||||
sharding_transform_executor:
|
||||
stage: sharding
|
||||
run_shape_prop: true
|
||||
############################################################################################
|
||||
# MOVE MODEL AND LOAD WEIGHTS
|
||||
############################################################################################
|
||||
load_weights:
|
||||
stage: weight_load
|
||||
############################################################################################
|
||||
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
|
||||
############################################################################################
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
|
||||
# fuse_moe:
|
||||
# stage: post_load_fusion
|
||||
# fuse_gemms:
|
||||
# stage: post_load_fusion
|
||||
fuse_allreduce_residual_rmsnorm:
|
||||
stage: post_load_fusion
|
||||
fuse_collectives:
|
||||
stage: post_load_fusion
|
||||
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
|
||||
# check if we can fuse rmsnorm
|
||||
fuse_rmsnorm:
|
||||
stage: post_load_fusion
|
||||
backend: flashinfer
|
||||
############################################################################################
|
||||
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
|
||||
############################################################################################
|
||||
update_in_out_nodes:
|
||||
stage: cache_init
|
||||
insert_cached_attention:
|
||||
stage: cache_init
|
||||
insert_cached_mla_attention:
|
||||
stage: cache_init
|
||||
attn_backend: MultiHeadLatentAttention
|
||||
initialize_cache:
|
||||
stage: cache_init
|
||||
resize_kv_cache:
|
||||
stage: cache_init
|
||||
############################################################################################
|
||||
# COMPILE MODEL
|
||||
############################################################################################
|
||||
compile_model:
|
||||
stage: compile
|
||||
|
||||
@ -198,7 +198,6 @@ def prepare_flashinfer_metadata(
|
||||
flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size),
|
||||
position_ids.numel(),
|
||||
)
|
||||
|
||||
# return metadata
|
||||
return (
|
||||
qo_indptr,
|
||||
|
||||
@ -274,6 +274,16 @@ class LlmArgs(AutoDeployConfig, BaseLlmArgs, BaseSettings):
|
||||
self._quant_config = value
|
||||
|
||||
### VALIDATION #################################################################################
|
||||
@field_validator("max_seq_len", mode="before")
|
||||
@classmethod
|
||||
def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any:
|
||||
if value is None:
|
||||
# Fallback to the AutoDeployConfig default when not provided
|
||||
return AutoDeployConfig.model_fields["max_seq_len"].get_default(
|
||||
call_default_factory=True
|
||||
)
|
||||
return value
|
||||
|
||||
@field_validator("build_config", mode="before")
|
||||
@classmethod
|
||||
def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any:
|
||||
|
||||
@ -54,6 +54,7 @@ class SharedConfig(BaseModel):
|
||||
sharding_config: ShardingConfig = Field(default_factory=ShardingConfig)
|
||||
local_rank: int = Field(default=0)
|
||||
world_size: int = Field(default=1)
|
||||
attn_backend: str = Field(default="flashinfer", description="The attention backend to use.")
|
||||
|
||||
|
||||
class TransformConfig(BaseModel):
|
||||
@ -285,7 +286,10 @@ class BaseTransform(ABC):
|
||||
# update + store new meta data
|
||||
history[t_name] = info
|
||||
autodeploy_meta[self._history_key] = history
|
||||
self._set_autodeploy_meta(gm, autodeploy_meta)
|
||||
|
||||
if isinstance(gm, GraphModule):
|
||||
# After compilation, gm becomes type CapturedGraph with no meta data.
|
||||
self._set_autodeploy_meta(gm, autodeploy_meta)
|
||||
|
||||
# return the graph module
|
||||
return gm
|
||||
|
||||
tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py (new file, 204 lines)
@ -0,0 +1,204 @@
|
||||
import operator
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...distributed.trtllm import is_trtllm_op_available
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
# TODO: This is an overly simplified model that works well for vanilla Llama models.
|
||||
# However, we eventually want to consider more sophisticated patterns such as
|
||||
# * all_reduce(lin1(x) + lin2(x))
|
||||
# * version above with fused GEMMs (i.e. with a split node)
|
||||
# * all_reduce(pointwise_op(linear(x)))
|
||||
# * ...
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_collectives")
|
||||
class FuseCollectives(BaseTransform):
|
||||
"""
|
||||
Fuses all_reduce ops with preceding (quantized) linear ops into a single fused node for improved performance.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
num_gemm_collective_fusions = 0
|
||||
|
||||
# lookup for fused ops
|
||||
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
|
||||
lookup = {
|
||||
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
|
||||
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
|
||||
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
|
||||
}
|
||||
|
||||
# go through all nodes and find all_reduce nodes
|
||||
for node in gm.graph.nodes:
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
|
||||
continue
|
||||
|
||||
# check if args are as expected
|
||||
assert len(node.args) == 1 and not len(node.kwargs), (
|
||||
"Unexpected args/kwargs for all_reduce"
|
||||
)
|
||||
|
||||
# retrieve parent and check a few conditions on the parent node
|
||||
parent_node = node.args[0]
|
||||
if not is_op(parent_node, lookup.keys()):
|
||||
continue
|
||||
if len(parent_node.users) > 1:
|
||||
continue
|
||||
|
||||
with gm.graph.inserting_before(node):
|
||||
# insert fused node
|
||||
fused_linear_collective_node = gm.graph.call_function(
|
||||
lookup[get_op_overload_packet(parent_node.target)],
|
||||
args=parent_node.args,
|
||||
kwargs=parent_node.kwargs,
|
||||
)
|
||||
node.replace_all_uses_with(fused_linear_collective_node)
|
||||
gm.graph.erase_node(node)
|
||||
gm.graph.erase_node(parent_node)
|
||||
num_gemm_collective_fusions += 1
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_gemm_collective_fusions,
|
||||
is_clean=False,
|
||||
has_valid_shapes=False,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_allreduce_residual_rmsnorm")
|
||||
class FuseAllreduceResidualRMSNorm(BaseTransform):
|
||||
"""Essentially, this transformation fuses the following operators into one allreduce trtllm implementation.
|
||||
|
||||
* target pattern:
|
||||
x = all_reduce(x)
|
||||
y = x + residual
|
||||
return rmsnorm(y), y
|
||||
* replacement:
|
||||
fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)
|
||||
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
if not is_trtllm_op_available():
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
num_ar_r_rms_fusions = 0
|
||||
|
||||
def trace_and_fuse(allreduce_node, graph):
|
||||
# Check if all_reduce is followed by addition
|
||||
users = list(allreduce_node.users.keys())
|
||||
if len(users) != 1:
|
||||
return # Skip if all_reduce has more than one consumer
|
||||
add_node = users[0]
|
||||
|
||||
# Traverse nodes for RMSNorm pattern which is composed of to_copy, pow, mean, add, refer
|
||||
# the Huggingface LlamaRMSNorm implementation as example for more details
|
||||
to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
|
||||
# operand of pow and mul
|
||||
pow_node = get_user_if_pattern_match(
|
||||
to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
|
||||
)
|
||||
mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
|
||||
add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
|
||||
rsqrt_node = get_user_if_pattern_match(
|
||||
add_eps_node, [torch.ops.aten.add, operator.add], 1
|
||||
)
|
||||
mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
|
||||
to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
|
||||
mul_node_2 = get_user_if_pattern_match(
|
||||
to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
|
||||
)
|
||||
# check args of ops: pow(2) and mean(-1)
|
||||
ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2 # exponent
|
||||
ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1] # dimensions
|
||||
|
||||
# Match found: Replace with fused operation
|
||||
if (
|
||||
to_copy_1
|
||||
and pow_node
|
||||
and mean_node
|
||||
and add_eps_node
|
||||
and rsqrt_node
|
||||
and mul_node_1
|
||||
and to_copy_2
|
||||
and mul_node_2
|
||||
and ARGS_MATCH
|
||||
):
|
||||
# Gather the inputs for the custom operation
|
||||
tensor = allreduce_node.args[0]
|
||||
# Identify the residual argument in the add operation
|
||||
# One of the args in add_node.args is the output of all_reduce
|
||||
# The same idea also applies to norm_weight
|
||||
residual = (
|
||||
add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
|
||||
)
|
||||
norm_weight = (
|
||||
mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
|
||||
)
|
||||
eps = add_eps_node.args[1]
|
||||
|
||||
# Insert nodes
|
||||
with graph.inserting_before(allreduce_node):
|
||||
fused_node = graph.call_function(
|
||||
torch.ops.dist.fused_allreduce_residual_rmsnorm,
|
||||
args=(
|
||||
tensor,
|
||||
residual,
|
||||
norm_weight,
|
||||
eps,
|
||||
),
|
||||
)
|
||||
# Extract outputs from the tuple returned by `fused_node`
|
||||
final_output_node = gm.graph.create_node(
|
||||
"call_function",
|
||||
target=operator.getitem,
|
||||
args=(fused_node, 0),
|
||||
)
|
||||
add_output_node = gm.graph.create_node(
|
||||
"call_function",
|
||||
target=operator.getitem,
|
||||
args=(fused_node, 1),
|
||||
)
|
||||
|
||||
# Replace all uses of rmsnorm_node with final_output_node
|
||||
mul_node_2.replace_all_uses_with(final_output_node)
|
||||
|
||||
# Replace all uses of add_node with add_output_node
|
||||
add_node.replace_all_uses_with(add_output_node)
|
||||
|
||||
nonlocal num_ar_r_rms_fusions
|
||||
num_ar_r_rms_fusions += 1
|
||||
|
||||
# Traverse all nodes
|
||||
for node in gm.graph.nodes:
|
||||
if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
|
||||
trace_and_fuse(allreduce_node=node, graph=gm.graph)
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=num_ar_r_rms_fusions, is_clean=False, has_valid_shapes=False
|
||||
)
|
||||
|
||||
return gm, info
|
||||
@ -0,0 +1,65 @@
|
||||
from typing import List, Literal, Optional, Tuple, Type
|
||||
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...compile import compile_and_capture
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
TransformConfig,
|
||||
TransformInfo,
|
||||
TransformRegistry,
|
||||
)
|
||||
|
||||
|
||||
class CompileModelConfig(TransformConfig):
|
||||
"""Configuration for the compile model transform."""
|
||||
|
||||
cuda_graph_batch_sizes: Optional[List[int]] = Field(
|
||||
default=None, description="The batch sizes to use for CUDA graphs."
|
||||
)
|
||||
num_batched_inputs: int = Field(
|
||||
default=2, description="The number of batched inputs to use for CUDA graphs."
|
||||
)
|
||||
compile_backend: Literal["torch-simple", "torch-compile", "torch-cudagraph", "torch-opt"] = (
|
||||
Field(description="The backend to use for compiling the model.")
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("compile_model")
|
||||
class CompileModel(BaseTransform):
|
||||
"""A transform to compile the model."""
|
||||
|
||||
config: CompileModelConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return CompileModelConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
cm.info.set_generate_only_batch()
|
||||
egm_compiled = compile_and_capture(
|
||||
gm,
|
||||
self.config.compile_backend,
|
||||
args=cm.args,
|
||||
dynamic_shapes=cm.dynamic_shapes,
|
||||
compiler_kwargs={
|
||||
"cuda_graph_batch_sizes": self.config.cuda_graph_batch_sizes,
|
||||
"num_batched_inputs": self.config.num_batched_inputs,
|
||||
},
|
||||
)
|
||||
cm.info.reset()
|
||||
|
||||
# store info object about the transform
|
||||
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
return egm_compiled, info
|
||||
@ -6,6 +6,8 @@ import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...utils.cuda_mem_tracker import cuda_memory_tracker
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import (
|
||||
@ -14,7 +16,7 @@ from ...utils.node_utils import (
|
||||
is_linear_op,
|
||||
)
|
||||
from ...utils.quantization_utils import QuantizationImpl
|
||||
from .._graph import canonicalize_graph
|
||||
from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry
|
||||
|
||||
|
||||
def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node]):
|
||||
@ -116,30 +118,36 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node
|
||||
gm.delete_all_unused_submodules()
|
||||
|
||||
|
||||
def fuse_gemms(gm: GraphModule) -> None:
|
||||
ad_logger.info("GEMM fusion")
|
||||
ad_logger.debug("Before GEMM fusion: " + str(gm))
|
||||
# sort linear nodes by parent node
|
||||
linear_nodes = defaultdict(list)
|
||||
for node in gm.graph.nodes:
|
||||
# TODO: we don't handle bias for now...
|
||||
if is_linear_op(node, include_quantization=True) and node.args[2] is None:
|
||||
linear_nodes[node.args[0]].append(node)
|
||||
@TransformRegistry.register("fuse_gemms")
|
||||
class FuseGemms(BaseTransform):
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
# sort linear nodes by parent node
|
||||
linear_nodes = defaultdict(list)
|
||||
for node in gm.graph.nodes:
|
||||
# TODO: we don't handle bias for now...
|
||||
if is_linear_op(node, include_quantization=True) and node.args[2] is None:
|
||||
linear_nodes[node.args[0]].append(node)
|
||||
|
||||
# fuse linear nodes
|
||||
idx = -1
|
||||
with cuda_memory_tracker():
|
||||
for parent_node, lin_children in linear_nodes.items():
|
||||
if len(lin_children) < 2:
|
||||
continue
|
||||
# linear nodes to fuse
|
||||
ad_logger.debug(
|
||||
f"Found linear nodes to fuse: {lin_children} with parent node: {parent_node}"
|
||||
)
|
||||
_insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
|
||||
# fuse linear nodes
|
||||
idx = -1
|
||||
num_matches = 0
|
||||
with cuda_memory_tracker():
|
||||
for parent_node, lin_children in linear_nodes.items():
|
||||
if len(lin_children) < 2:
|
||||
continue
|
||||
# linear nodes to fuse
|
||||
_insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children)
|
||||
num_matches += 1
|
||||
|
||||
# clean up and return
|
||||
canonicalize_graph(gm)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
ad_logger.debug("After GEMM fusion: " + str(gm))
|
||||
torch.cuda.empty_cache()
|
||||
info = TransformInfo(
|
||||
skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False
|
||||
)
|
||||
return gm, info
|
||||
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py (new file, 299 lines)
@ -0,0 +1,299 @@
|
||||
"""Graph transformation to automatically add kv cache into fused MHA op."""
|
||||
|
||||
import operator
|
||||
from typing import Dict, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch.fx import Graph, GraphModule, Node
|
||||
|
||||
from ...custom_ops.attention_interface import AttentionRegistry
|
||||
from ...distributed.common import all_gather_object, get_world_size
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
from ...transformations._graph import add_graph_input
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import get_all_input_output_nodes, is_op
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
TransformConfig,
|
||||
TransformInfo,
|
||||
TransformRegistry,
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("update_in_out_nodes")
|
||||
class UpdateInOutNodes(BaseTransform):
|
||||
"""Modify the graph module by adding new input nodes.
|
||||
|
||||
The new input nodes correspond to the extra arguments needed for cached and flattened attention.
|
||||
|
||||
Args:
|
||||
egm: The graph module to analyze and modify.
|
||||
cm: Cached sequence interface containing extra argument information.
|
||||
"""
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
# loop through nodes to get input, output, and get_attr nodes
|
||||
input_nodes, output_nodes = get_all_input_output_nodes(gm.graph)
|
||||
|
||||
# we only expect one input node
|
||||
assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
|
||||
|
||||
# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
|
||||
# Later on, we can revisit how to support returning hidden states.
|
||||
assert len(output_nodes) == 1, "Expected exactly one output node!"
|
||||
assert len(output_nodes[0].all_input_nodes) == 1, (
|
||||
"Expected to only return final tensor output!"
|
||||
)
|
||||
|
||||
# Activate and add extra argument nodes
|
||||
new_args = cm.info.switch_to_cached_attn_inputs()
|
||||
for name in new_args:
|
||||
input_nodes.append(add_graph_input(gm, name))
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=1, is_clean=False, has_valid_shapes=False)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
class InsertCachedAttentionConfig(TransformConfig):
|
||||
"""Configuration for the insert cached attention transform."""
|
||||
|
||||
attn_backend: Optional[str] = Field(default=None, description="The attention backend to use.")
|
||||
|
||||
|
||||
@TransformRegistry.register("insert_cached_attention")
|
||||
class InsertCachedAttention(BaseTransform):
|
||||
"""
|
||||
A transform to insert cached attention into the graph module.
|
||||
|
||||
If attn_backend is not provided in transform config, will find from shared config.
|
||||
"""
|
||||
|
||||
config: InsertCachedAttentionConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return InsertCachedAttentionConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
"""Replace uncached source attention node with corresponding cached attn node."""
|
||||
attn_descriptor = AttentionRegistry.get(self.config.attn_backend)
|
||||
|
||||
cache_config = factory.get_cache_config()
|
||||
|
||||
# Get all attention nodes and their info objects
|
||||
source_op = attn_descriptor.get_source_attention_op()
|
||||
|
||||
# pick up graph
|
||||
graph: Graph = gm.graph
|
||||
|
||||
# look for relevant source attention nodes
|
||||
source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)]
|
||||
|
||||
if not source_attn_nodes:
|
||||
# If there are no nodes for kv cache insertion found, return current graph
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
# Sanity check
|
||||
if cm.info.is_paged:
|
||||
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
|
||||
|
||||
# retrieve input nodes
|
||||
input_nodes, _ = get_all_input_output_nodes(gm.graph)
|
||||
|
||||
# insert metadata computation and extract each argument as a node
|
||||
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
|
||||
with graph.inserting_before(input_nodes[-1].next):
|
||||
ret_node = graph.call_function(
|
||||
get_metadata,
|
||||
args=(
|
||||
*input_nodes,
|
||||
cm.info.page_size,
|
||||
),
|
||||
)
|
||||
metadata_nodes = [
|
||||
graph.call_function(operator.getitem, args=(ret_node, idx))
|
||||
for idx in range(num_metadata)
|
||||
]
|
||||
|
||||
buffer_in_lookup: Dict[str, Node] = {}
|
||||
|
||||
# replace fused attention node with attention node that has kv cache
|
||||
num_cached_attn_replacements = 0
|
||||
for idx, attn_node in enumerate(source_attn_nodes):
|
||||
# pick out GEMMs
|
||||
qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()]
|
||||
|
||||
# setup + store cache initializers and caches as input nodes
|
||||
cache_in_nodes = []
|
||||
for k, get_cache in attn_descriptor.get_cache_initializers(
|
||||
attn_node, cache_config
|
||||
).items():
|
||||
k_indexed = f"{k}_{idx}"
|
||||
cm.add_cache(k_indexed, get_cache)
|
||||
cache_in_nodes.append(add_graph_input(gm, k_indexed))
|
||||
|
||||
# setup + store global buffer initializers and buffers as input nodes
|
||||
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
|
||||
buffer_in_nodes = []
|
||||
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
|
||||
if k not in buffer_in_lookup:
|
||||
cm.add_cache(k, get_buffer)
|
||||
buffer_in_lookup[k] = add_graph_input(gm, k)
|
||||
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
|
||||
|
||||
# retrieve constants for attention_op
|
||||
constants = attn_descriptor.get_constants(attn_node)
|
||||
|
||||
# insert cached attention replacement op
|
||||
with graph.inserting_before(attn_node):
|
||||
cached_attn_node = graph.call_function(
|
||||
attn_descriptor.get_cached_attention_op(),
|
||||
args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants),
|
||||
)
|
||||
attn_node.replace_all_uses_with(cached_attn_node)
|
||||
graph.erase_node(attn_node)
|
||||
num_cached_attn_replacements += 1
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=num_cached_attn_replacements,
|
||||
is_clean=False,
|
||||
has_valid_shapes=False,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
@TransformRegistry.register("insert_cached_mla_attention")
|
||||
class InsertCachedMLAAttention(InsertCachedAttention):
|
||||
"""
|
||||
A transform to insert cached MLA attention into the graph module.
|
||||
|
||||
This class is identical to InsertCachedAttention and inherits all its behavior.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ResizeKVCacheConfig(TransformConfig):
|
||||
"""Configuration for the resize kv cache transform."""
|
||||
|
||||
free_mem_ratio: float = Field(
|
||||
description="The fraction of available memory to occupy.", default=0.8
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("resize_kv_cache")
|
||||
class ResizeKVCache(BaseTransform):
|
||||
"""Inflate the kv cache to occupy the available GPU memory.
|
||||
|
||||
free_mem_ratio specifies the fraction of available memory to occupy.
|
||||
"""
|
||||
|
||||
config: ResizeKVCacheConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return ResizeKVCacheConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
free_mem_ratio = self.config.free_mem_ratio
|
||||
|
||||
def _get_mem_info_in_mb():
|
||||
free_mem, total_mem = torch.cuda.mem_get_info()
|
||||
return free_mem // 1024**2, total_mem // 1024**2
|
||||
|
||||
free_mem, total_mem = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
|
||||
current_cache_size = cm.current_cache_size_bytes()
|
||||
current_num_pages = cm.info.num_pages
|
||||
ad_logger.info(
|
||||
f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
|
||||
)
|
||||
|
||||
if free_mem_ratio == 0.0:
|
||||
ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}")
|
||||
return gm, TransformInfo(
|
||||
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
|
||||
)
|
||||
|
||||
try:
|
||||
# Let's run a forward pass to get the memory usage
|
||||
cm.info._set_max_num_tokens_sample()
|
||||
free_mem_pre, _ = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
|
||||
|
||||
gm(*cm.args)
|
||||
|
||||
free_mem_post, _ = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")
|
||||
|
||||
memory_for_forward_pass = free_mem_pre - free_mem_post
|
||||
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
|
||||
|
||||
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
|
||||
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
|
||||
|
||||
# Need to sync all the GPUs
|
||||
gathered_num_pages = [None] * get_world_size()
|
||||
all_gather_object(gathered_num_pages, new_num_pages)
|
||||
new_num_pages = min(gathered_num_pages)
|
||||
ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}")
|
||||
|
||||
cm.resize_cache(new_num_pages)
|
||||
except Exception as e:
|
||||
ad_logger.warning(
|
||||
f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
|
||||
)
|
||||
|
||||
# Free memory
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
info = TransformInfo(
|
||||
skipped=False,
|
||||
num_matches=0,
|
||||
is_clean=True,
|
||||
has_valid_shapes=True,
|
||||
)
|
||||
|
||||
return gm, info
|
||||
|
||||
|
||||
@TransformRegistry.register("initialize_cache")
|
||||
class InitializeCache(BaseTransform):
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
cm.initialize_caches()
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=1, is_clean=True, has_valid_shapes=True)
|
||||
|
||||
return gm, info
|
||||
tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py (new file, 148 lines)
@ -0,0 +1,148 @@
|
||||
"""Graph transform to optimize RMSNorm execution using FlashInfer."""
|
||||
|
||||
from functools import partial
|
||||
from typing import Tuple, Type
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...models.factory import ModelFactory
|
||||
from ...shim.interface import CachedSequenceInterface
|
||||
|
||||
# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
|
||||
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
|
||||
from ..interface import (
|
||||
BaseTransform,
|
||||
SharedConfig,
|
||||
TransformConfig,
|
||||
TransformInfo,
|
||||
TransformRegistry,
|
||||
)
|
||||
|
||||
_BACKEND_OPS = {
|
||||
"flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
|
||||
"triton": torch.ops.auto_deploy.triton_rms_norm,
|
||||
"torch": torch.ops.auto_deploy.torch_rmsnorm,
|
||||
}
|
||||
|
||||
|
||||
def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
|
||||
"""Implements the RMSNorm pattern for pattern matching.
|
||||
|
||||
Args:
|
||||
data: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
|
||||
Returns:
|
||||
Normalized and scaled tensor.
|
||||
"""
|
||||
input_dtype = data.dtype
|
||||
data = data.to(torch.float32)
|
||||
variance = data.pow(2).mean(-1, keepdim=True)
|
||||
data = data * torch.rsqrt(variance + eps)
|
||||
return weight * data.to(input_dtype)
|
||||
|
||||
|
||||
def _rms_norm_replacement(
|
||||
data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
|
||||
) -> torch.Tensor:
|
||||
"""Backend-specific rms_norm implementation.
|
||||
|
||||
Args:
|
||||
data: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
|
||||
|
||||
Returns:
|
||||
Normalized and scaled tensor using the specified backend implementation.
|
||||
"""
|
||||
|
||||
assert backend.lower() in _BACKEND_OPS, (
|
||||
f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
|
||||
)
|
||||
return _BACKEND_OPS[backend.lower()](data, weight, eps)
|
||||
|
||||
|
||||
class FuseRMSNormConfig(TransformConfig):
|
||||
"""Configuration for the RMSNorm fusion transform."""
|
||||
|
||||
backend: str = Field(
|
||||
default="flashinfer",
|
||||
description="Backend to use for RMSNorm computation ('flashinfer' or 'triton').",
|
||||
)
|
||||
|
||||
|
||||
@TransformRegistry.register("fuse_rmsnorm")
|
||||
class FuseRMSNorm(BaseTransform):
|
||||
"""Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
|
||||
|
||||
This function sets up pattern matching to identify RMSNorm operations in the graph
|
||||
and replaces them with optimized implementations. It uses dummy tensors to register
|
||||
the pattern matching rules.
|
||||
|
||||
Args:
|
||||
gm: Input graph module to transform.
|
||||
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
|
||||
|
||||
Returns:
|
||||
Transformed graph module with optimized RMSNorm operations.
|
||||
"""
|
||||
|
||||
config: FuseRMSNormConfig
|
||||
|
||||
@classmethod
|
||||
def get_config_class(cls) -> Type[TransformConfig]:
|
||||
return FuseRMSNormConfig
|
||||
|
||||
def _apply(
|
||||
self,
|
||||
gm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
factory: ModelFactory,
|
||||
shared_config: SharedConfig,
|
||||
) -> Tuple[GraphModule, TransformInfo]:
|
||||
if self.config.backend.lower() not in _BACKEND_OPS:
|
||||
raise ValueError(
|
||||
f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {self.config.backend}"
|
||||
)
|
||||
|
||||
graph = gm.graph
|
||||
patterns = ADPatternMatcherPass()
|
||||
|
||||
# Create dummy tensors for pattern matching
|
||||
bs = 2
|
||||
hidden_size = 512
|
||||
|
||||
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
|
||||
return [
|
||||
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
|
||||
torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
|
||||
eps,
|
||||
]
|
||||
|
||||
# Define configurations for different data types
|
||||
configs = [
|
||||
(torch.bfloat16, torch.bfloat16),
|
||||
(torch.float16, torch.float16),
|
||||
(torch.float32, torch.float32),
|
||||
]
|
||||
|
||||
# Register patterns for each configuration
|
||||
for input_dtype, weight_dtype in configs:
|
||||
register_ad_pattern(
|
||||
search_fn=_rms_norm_pattern,
|
||||
replace_fn=partial(_rms_norm_replacement, backend=self.config.backend),
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args(input_dtype, weight_dtype),
|
||||
op_ignore_types={},
|
||||
scalar_workaround={"eps": 1e-6},
|
||||
)
|
||||
|
||||
cnt = patterns.apply(graph)
|
||||
|
||||
info = TransformInfo(skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=False)
|
||||
|
||||
return gm, info
|
||||
@ -1,11 +1,5 @@
|
||||
"""A library of transformation passes."""
|
||||
|
||||
from .collectives import *
|
||||
from .fused_moe import *
|
||||
from .fusion import *
|
||||
from .kvcache import *
|
||||
from .rms_norm import *
|
||||
|
||||
try:
|
||||
from .visualization import visualize_namespace
|
||||
except ImportError:
|
||||
|
||||
@ -1,167 +0,0 @@
|
||||
import operator
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
|
||||
from ...distributed.trtllm import is_trtllm_op_available
|
||||
from ...utils.logger import ad_logger
|
||||
from ...utils.node_utils import get_op_overload_packet, get_user_if_pattern_match, is_op
|
||||
from .._graph import canonicalize_graph
|
||||
|
||||
|
||||
# TODO: This is an overly simplified model that works well for vanilla Llama models.
|
||||
# However, we eventually want to consider more sophisticated patterns such as
|
||||
# * all_reduce(lin1(x) + lin2(x))
|
||||
# * version above with fused GEMMs (i.e. with a split node)
|
||||
# * all_reduce(pointwise_op(linear(x)))
|
||||
# * ...
|
||||
def fuse_collectives(gm: GraphModule) -> None:
|
||||
num_gemm_collective_fusions = 0
|
||||
ad_logger.debug("Before GEMM+Collective fusion: " + str(gm))
|
||||
|
||||
# lookup for fused ops
|
||||
# TODO: avoid this hardcoded lookup, e.g., by generating fused ops on the fly.
|
||||
lookup = {
|
||||
torch.ops.auto_deploy.torch_linear_simple: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
|
||||
torch.ops.aten.linear: torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce,
|
||||
torch.ops.auto_deploy.torch_quant_fp8_linear: torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce,
|
||||
}
|
||||
|
||||
# go through all nodes and find all_reduce nodes
|
||||
for node in gm.graph.nodes:
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
|
||||
continue
|
||||
|
||||
# check if args are as expected
|
||||
assert len(node.args) == 1 and not len(node.kwargs), "Unexpected args/kwargs for all_reduce"
|
||||
|
||||
# retrieve parent and check a few conditions on the parent node
|
||||
parent_node = node.args[0]
|
||||
if not is_op(parent_node, lookup.keys()):
|
||||
continue
|
||||
if len(parent_node.users) > 1:
|
||||
continue
|
||||
|
||||
with gm.graph.inserting_before(node):
|
||||
# insert fused node
|
||||
fused_linear_collective_node = gm.graph.call_function(
|
||||
lookup[get_op_overload_packet(parent_node.target)],
|
||||
args=parent_node.args,
|
||||
kwargs=parent_node.kwargs,
|
||||
)
|
||||
node.replace_all_uses_with(fused_linear_collective_node)
|
||||
gm.graph.erase_node(node)
|
||||
gm.graph.erase_node(parent_node)
|
||||
num_gemm_collective_fusions += 1
|
||||
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.info(f"Found {num_gemm_collective_fusions} GEMM+Collective fusions")
|
||||
ad_logger.debug("After GEMM+Collective fusion: " + str(gm))
|
||||
|
||||
|
||||
def fuse_allreduce_residual_rmsnorm(gm: GraphModule) -> None:
    """Fuse allreduce, residual add, and RMSNorm into a single TRT-LLM allreduce implementation.

    * target pattern:
        x = all_reduce(x)
        y = x + residual
        return rmsnorm(y), y
    * replacement:
        fused_allreduce_residual_rmsnorm(x, residual, rmsnorm_weight, rmsnorm_eps)
    """
    if not is_trtllm_op_available():
        return

    num_ar_r_rms_fusions = 0
    ad_logger.debug("Before allreduce+residual+rmsnorm fusion: " + str(gm))

    def trace_and_fuse(allreduce_node, graph):
        # Check if all_reduce is followed by an addition
        users = list(allreduce_node.users.keys())
        if len(users) != 1:
            return  # Skip unless all_reduce has exactly one consumer
        add_node = users[0]

        # Traverse nodes for the RMSNorm pattern, which is composed of to_copy, pow, mean, add,
        # rsqrt, and mul; see the Hugging Face LlamaRMSNorm implementation for reference.
        to_copy_1 = get_user_if_pattern_match(add_node, [torch.ops.aten.add, operator.add], 2)
        # operand of pow and mul
        pow_node = get_user_if_pattern_match(
            to_copy_1, [torch.ops.aten._to_copy, torch.ops.aten.to], 2
        )
        mean_node = get_user_if_pattern_match(pow_node, torch.ops.aten.pow, 1)
        add_eps_node = get_user_if_pattern_match(mean_node, torch.ops.aten.mean, 1)
        rsqrt_node = get_user_if_pattern_match(add_eps_node, [torch.ops.aten.add, operator.add], 1)
        mul_node_1 = get_user_if_pattern_match(rsqrt_node, torch.ops.aten.rsqrt, 1)
        to_copy_2 = get_user_if_pattern_match(mul_node_1, torch.ops.aten.mul, 1)
        mul_node_2 = get_user_if_pattern_match(
            to_copy_2, [torch.ops.aten._to_copy, torch.ops.aten.to], 1
        )
        # check args of ops: pow(2) and mean(-1)
        ARGS_MATCH = pow_node is not None and pow_node.args[1] == 2  # exponent
        ARGS_MATCH &= mean_node is not None and mean_node.args[1] == [-1]  # dimensions

        # Match found: replace with the fused operation
        if (
            to_copy_1
            and pow_node
            and mean_node
            and add_eps_node
            and rsqrt_node
            and mul_node_1
            and to_copy_2
            and mul_node_2
            and ARGS_MATCH
        ):
            # Gather the inputs for the custom operation
            tensor = allreduce_node.args[0]
            # Identify the residual argument in the add operation:
            # one of the args in add_node.args is the output of all_reduce.
            # The same idea also applies to norm_weight.
            residual = add_node.args[0] if add_node.args[1] is allreduce_node else add_node.args[1]
            norm_weight = (
                mul_node_2.args[0] if mul_node_2.args[1] is to_copy_2 else mul_node_2.args[1]
            )
            eps = add_eps_node.args[1]

            # Insert nodes
            with graph.inserting_before(allreduce_node):
                fused_node = graph.call_function(
                    torch.ops.dist.fused_allreduce_residual_rmsnorm,
                    args=(
                        tensor,
                        residual,
                        norm_weight,
                        eps,
                    ),
                )
                # Extract outputs from the tuple returned by `fused_node`
                final_output_node = gm.graph.create_node(
                    "call_function",
                    target=operator.getitem,
                    args=(fused_node, 0),
                )
                add_output_node = gm.graph.create_node(
                    "call_function",
                    target=operator.getitem,
                    args=(fused_node, 1),
                )

            # Replace all uses of rmsnorm_node with final_output_node
            mul_node_2.replace_all_uses_with(final_output_node)

            # Replace all uses of add_node with add_output_node
            add_node.replace_all_uses_with(add_output_node)

            nonlocal num_ar_r_rms_fusions
            num_ar_r_rms_fusions += 1

    # Traverse all nodes
    for node in gm.graph.nodes:
        if is_op(node, torch.ops.auto_deploy.torch_dist_all_reduce):
            trace_and_fuse(allreduce_node=node, graph=gm.graph)

    canonicalize_graph(gm)
    ad_logger.info(f"Found {num_ar_r_rms_fusions} allreduce+residual+rmsnorm fusions")
    ad_logger.debug("After allreduce+residual+rmsnorm fusion: " + str(gm))
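`get_user_if_pattern_match` comes from the node utilities and is not shown in this diff. For the chain above it only needs to return a user of the given node when the node's op and user count match expectations, and `None` otherwise. A hypothetical sketch of such a helper follows; the name, signature, and exact semantics here are assumptions for illustration, not the repository's implementation.

```python
from typing import Optional, Sequence, Union

from torch.fx import Node


def _user_if_pattern_match(
    node: Optional[Node],
    expected_ops: Union[object, Sequence[object]],
    num_users: int,
) -> Optional[Node]:
    """Return a user of `node` if `node` matches `expected_ops` and has `num_users` users."""
    if node is None:
        return None
    ops = expected_ops if isinstance(expected_ops, (list, tuple, set)) else [expected_ops]
    # The real is_op-based check also resolves overload packets; this sketch only
    # compares the raw target.
    if node.target not in ops:
        return None
    if len(node.users) != num_users:
        return None
    return next(iter(node.users), None)
```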
@ -1,511 +0,0 @@
from collections import defaultdict
from typing import Optional

import torch
from torch.fx import GraphModule, Node

from ...utils.cuda_mem_tracker import cuda_memory_tracker
from ...utils.logger import ad_logger
from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op
from ...utils.quantization_utils import get_scales_and_type_from_node
from .._graph import canonicalize_graph
def match_moe_pattern(gm: GraphModule) -> None:
|
||||
graph = gm.graph
|
||||
|
||||
ad_logger.debug("Before MoE Pattern Matching: " + str(gm))
|
||||
# Preprocessing: Identify boundary nodes (e.g. residual connections) in the graph.
|
||||
boundary_nodes = identify_regions_between_residuals(gm)
|
||||
|
||||
num_moe_patterns = 0
|
||||
|
||||
for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]):
|
||||
# Step 1: Identify Expert Compute pattern
|
||||
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type) = (
|
||||
_match_expert_compute_pattern(start_boundary, end_boundary)
|
||||
)
|
||||
if not expert_weights:
|
||||
continue
|
||||
# TODO: naming convention to verify the order of the weight nodes
|
||||
|
||||
# Step 2: Trace upwards to locate normalize_routing_weight and selected_experts:
|
||||
arg1_list, arg2_list = _extract_index_branches_from_expert_outputs(pattern_output_nodes)
|
||||
normalized_routing_weights = _find_lowest_common_ancessor(arg1_list)
|
||||
if not normalized_routing_weights:
|
||||
continue
|
||||
|
||||
common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
|
||||
if not common_ancessor2:
|
||||
continue
|
||||
selected_experts = bfs(
|
||||
common_ancessor2,
|
||||
lambda node: is_op(node, torch.ops.aten.one_hot),
|
||||
attr_next="all_input_nodes",
|
||||
boundary=start_boundary,
|
||||
).args[0]
|
||||
if not selected_experts:
|
||||
continue
|
||||
|
||||
# Step 3: Trace upwards to find input node:
|
||||
hidden_states = _find_lowest_common_ancessor(pattern_input_nodes)
|
||||
if not hidden_states:
|
||||
continue
|
||||
|
||||
# Step 4: Find output node with the combine pattern
|
||||
final_hidden_state_node = _find_final_hidden_state_node(pattern_output_nodes, end_boundary)
|
||||
if final_hidden_state_node is None:
|
||||
continue
|
||||
|
||||
# Step 5: Insert the MoE op into the graph.
|
||||
ad_logger.debug(
|
||||
f"Found MoE Pattern: between boundary {start_boundary} and {end_boundary}.\n"
|
||||
f"Input hidden states node: {hidden_states}, "
|
||||
f"selected_experts node: {selected_experts}, "
|
||||
f"routing_weights node: {normalized_routing_weights}, "
|
||||
f"expert weights: {expert_weights}, weight type: {weight_type}"
|
||||
)
|
||||
with graph.inserting_before(final_hidden_state_node):
|
||||
w1_list = expert_weights["w1"]
|
||||
w2_list = expert_weights["w2"]
|
||||
w3_list = expert_weights["w3"]
|
||||
|
||||
if weight_type == "fp8":
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_quant_fp8_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
expert_scales["w1_input_scale"],
|
||||
expert_scales["w2_input_scale"],
|
||||
expert_scales["w3_input_scale"],
|
||||
expert_scales["w1_weight_scale"],
|
||||
expert_scales["w2_weight_scale"],
|
||||
expert_scales["w3_weight_scale"],
|
||||
),
|
||||
)
|
||||
elif weight_type == "fp4":
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_quant_fp4_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
expert_scales["w1_input_scale"],
|
||||
expert_scales["w2_input_scale"],
|
||||
expert_scales["w3_input_scale"],
|
||||
expert_scales["w1_weight_scale"],
|
||||
expert_scales["w2_weight_scale"],
|
||||
expert_scales["w3_weight_scale"],
|
||||
expert_scales["w1_alpha"],
|
||||
expert_scales["w2_alpha"],
|
||||
expert_scales["w3_alpha"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
fused_moe_node = graph.call_function(
|
||||
torch.ops.auto_deploy.torch_moe,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
normalized_routing_weights,
|
||||
w1_list,
|
||||
w2_list,
|
||||
w3_list,
|
||||
),
|
||||
)
|
||||
|
||||
final_hidden_state_node.replace_all_uses_with(fused_moe_node)
|
||||
graph.erase_node(final_hidden_state_node)
|
||||
|
||||
while _remove_dead_inplace_nodes_in_region(gm.graph, start_boundary, end_boundary):
|
||||
gm.graph.eliminate_dead_code()
|
||||
|
||||
num_moe_patterns += 1
|
||||
|
||||
canonicalize_graph(gm)
|
||||
|
||||
ad_logger.info(f"Found {num_moe_patterns} MoE Patterns")
|
||||
ad_logger.debug("After MoE Pattern Matching: " + str(gm))
|
||||
|
||||
|
||||
def fuse_moe(gm: torch.fx.GraphModule) -> None:
|
||||
"""
|
||||
Scan the FX graph and replace all calls to torch.ops.auto_deploy.torch_moe with
|
||||
torch.ops.auto_deploy.trtllm_moe_fused.
|
||||
"""
|
||||
ad_logger.debug("Before MoE fusion: " + str(gm))
|
||||
|
||||
with cuda_memory_tracker():
|
||||
fused_key_counter = _insert_fused_moe_ops(gm)
|
||||
if fused_key_counter:
|
||||
canonicalize_graph(gm)
|
||||
|
||||
ad_logger.info(f"Found {fused_key_counter} MoE fusions")
|
||||
ad_logger.debug("After MoE fusion: " + str(gm))
|
||||
|
||||
|
||||
def _insert_fused_moe_ops(gm: GraphModule) -> int:
|
||||
fused_key_counter = 0
|
||||
graph = gm.graph
|
||||
|
||||
for node in list(graph.nodes):
|
||||
if not is_op(node, torch.ops.auto_deploy.torch_moe):
|
||||
continue
|
||||
|
||||
ad_logger.debug(f"Found MoE op to fuse: {node} with args: {node.args}")
|
||||
hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = node.args
|
||||
|
||||
fused_w3_w1_experts = torch.stack(
|
||||
[
|
||||
torch.cat(
|
||||
[gm.get_parameter(w3_node.target), gm.get_parameter(w1_node.target)], dim=-2
|
||||
)
|
||||
for w1_node, w3_node in zip(w1_list, w3_list)
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
fused_w2_experts = torch.stack([gm.get_parameter(n.target) for n in w2_list], dim=0)
|
||||
|
||||
new_key_w3_w1 = f"fused_moe_w3_w1_stacked_{fused_key_counter}"
|
||||
new_key_w2 = f"fused_moe_w2_stacked_{fused_key_counter}"
|
||||
fused_key_counter += 1
|
||||
param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts)
|
||||
param_w2 = torch.nn.Parameter(fused_w2_experts)
|
||||
gm.register_parameter(new_key_w3_w1, param_w3_w1)
|
||||
gm.register_parameter(new_key_w2, param_w2)
|
||||
|
||||
with graph.inserting_before(node):
|
||||
new_node = graph.call_function(
|
||||
# TODO(Fridah-nv): torch.ops.auto_deploy.trtllm_moe_fused for quantized models
|
||||
torch.ops.auto_deploy.trtllm_moe_fused,
|
||||
args=(
|
||||
hidden_states,
|
||||
selected_experts,
|
||||
routing_weights,
|
||||
graph.get_attr(new_key_w3_w1),
|
||||
graph.get_attr(new_key_w2),
|
||||
),
|
||||
)
|
||||
|
||||
node.replace_all_uses_with(new_node)
|
||||
graph.erase_node(node)
|
||||
|
||||
return fused_key_counter
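The stacking above turns the per-expert `w1`/`w3` matrices into a single `[num_experts, 2 * intermediate, hidden]` tensor and the `w2` matrices into `[num_experts, hidden, intermediate]`. A quick shape check with made-up sizes (the dimensions are illustrative only):

```python
import torch

num_experts, hidden, intermediate = 4, 64, 128
w1 = [torch.randn(intermediate, hidden) for _ in range(num_experts)]
w3 = [torch.randn(intermediate, hidden) for _ in range(num_experts)]
w2 = [torch.randn(hidden, intermediate) for _ in range(num_experts)]

# Concatenate w3 on top of w1 per expert, then stack across experts.
fused_w3_w1 = torch.stack(
    [torch.cat([w3_i, w1_i], dim=-2) for w1_i, w3_i in zip(w1, w3)], dim=0
)
fused_w2 = torch.stack(w2, dim=0)

print(fused_w3_w1.shape)  # torch.Size([4, 256, 64])
print(fused_w2.shape)     # torch.Size([4, 64, 128])
```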
|
||||
|
||||
|
||||
def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
|
||||
"""
|
||||
Find the lowest common ancestor for a list of nodes in a torch.fx Graph by following
|
||||
each node's primary branch (recursively following the first Node argument).
|
||||
|
||||
It first finds the LCA of the first two nodes and then
|
||||
iteratively computes the LCA of the result with the next node, and so on.
|
||||
|
||||
Returns:
|
||||
The common ancestor Node if found, otherwise None.
|
||||
"""
|
||||
if not nodes:
|
||||
return None
|
||||
|
||||
def get_parent(node: Node) -> Optional[Node]:
|
||||
"""Return the first Node-valued argument for a given node, or None if not found."""
|
||||
for arg in node.args:
|
||||
if isinstance(arg, Node):
|
||||
return arg
|
||||
return None
|
||||
|
||||
def get_depth(node: Node) -> int:
|
||||
"""
|
||||
Recursively compute the depth of the node by following its primary branch.
|
||||
Depth is defined as the number of steps to reach a node with no parent.
|
||||
"""
|
||||
parent = get_parent(node)
|
||||
if parent is None:
|
||||
return 0
|
||||
return 1 + get_depth(parent)
|
||||
|
||||
def lca_two(a: Node, b: Node) -> Optional[Node]:
|
||||
"""
|
||||
Find the lowest common ancestor of two nodes by first equalizing their depth
|
||||
and then moving upward until a common node is found.
|
||||
"""
|
||||
depth_a = get_depth(a)
|
||||
depth_b = get_depth(b)
|
||||
|
||||
# Equalize depths
|
||||
while depth_a > depth_b:
|
||||
a = get_parent(a)
|
||||
depth_a -= 1
|
||||
while depth_b > depth_a:
|
||||
b = get_parent(b)
|
||||
depth_b -= 1
|
||||
|
||||
# Walk upward in lockstep
|
||||
while a is not None and b is not None:
|
||||
if a is b:
|
||||
return a
|
||||
a = get_parent(a)
|
||||
b = get_parent(b)
|
||||
return None
|
||||
|
||||
# Iteratively compute the LCA across all nodes.
|
||||
common = nodes[0]
|
||||
for node in nodes[1:]:
|
||||
common = lca_two(common, node)
|
||||
if common is None:
|
||||
return None
|
||||
|
||||
return common
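The depth-equalize-then-walk strategy above is the classic LCA algorithm for parent-pointer structures. A small standalone illustration of the same idea on a toy parent map (plain strings instead of FX nodes, purely for illustration):

```python
from typing import Dict, Optional

# Toy DAG as a child -> parent map (None marks the root), standing in for the
# "primary branch" parent relation used on FX nodes above.
PARENTS: Dict[str, Optional[str]] = {
    "hidden": None,
    "w1_out": "hidden",
    "silu": "w1_out",
    "w3_out": "hidden",
    "mul": "silu",
}


def depth(n: str) -> int:
    d = 0
    while PARENTS[n] is not None:
        n = PARENTS[n]
        d += 1
    return d


def lca(a: str, b: str) -> Optional[str]:
    da, db = depth(a), depth(b)
    # Equalize depths, then walk upward in lockstep until the paths meet.
    while da > db:
        a, da = PARENTS[a], da - 1
    while db > da:
        b, db = PARENTS[b], db - 1
    while a is not None and b is not None:
        if a == b:
            return a
        a, b = PARENTS[a], PARENTS[b]
    return None


print(lca("mul", "w3_out"))  # -> "hidden"
```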
|
||||
|
||||
|
||||
def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]:
|
||||
"""
|
||||
Given a linear op node, extract the input tensor node, weight tensor,
|
||||
any quantization scales (if the op is quantized), and return a weight type.
|
||||
|
||||
For a torch.ops.auto_deploy.torch_linear_simple.default op:
|
||||
- Returns (input_node, weight, None, "simple")
|
||||
|
||||
For a torch.ops.auto_deploy.torch_quant_fp8_linear op:
|
||||
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8")
|
||||
For a torch.ops.auto_deploy.torch_quant_fp4_linear op:
|
||||
- Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4")
|
||||
"""
|
||||
input_node = linear_node.args[0]
|
||||
if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple):
|
||||
weight = linear_node.args[1]
|
||||
return input_node, weight, None, ""
|
||||
elif is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear) or is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear):
|
||||
weight = linear_node.args[1]
|
||||
scales, quant_type = get_scales_and_type_from_node(linear_node)
|
||||
return input_node, weight, scales, quant_type
|
||||
|
||||
|
||||
def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node):
|
||||
"""
|
||||
Match the expert compute pattern between the given boundaries.
|
||||
|
||||
The expert compute pattern corresponds to:
|
||||
|
||||
(F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()
|
||||
|
||||
For each expert, the function extracts the input node from the w1 branch and
|
||||
collects the weight parameters from three linear ops (w1, w3, and w2 branches).
|
||||
|
||||
This function supports both:
|
||||
- torch.ops.auto_deploy.torch_linear_simple.default ops, and
|
||||
- torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales).
|
||||
- torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales).
|
||||
|
||||
Returns:
|
||||
A tuple:
|
||||
(pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type)
|
||||
|
||||
- pattern_input_nodes: List of input nodes (x) used for the expert compute.
|
||||
- pattern_output_nodes: List of final expert output nodes (the linear op with weight w2).
|
||||
- expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors.
|
||||
- expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors
|
||||
(empty if weight_type is "simple").
|
||||
- weight_type: "fp8" if FP8 ops were used, "simple" otherwise.
|
||||
"""
|
||||
pattern_input_nodes, pattern_output_nodes = [], []
|
||||
expert_weights = defaultdict(list)
|
||||
expert_scales = defaultdict(list)
|
||||
weight_type = "simple" # default
|
||||
|
||||
nodes = list(start_boundary.graph.nodes)
|
||||
region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)]
|
||||
|
||||
for node in region_nodes:
|
||||
# Accept both simple and quantized linear ops.
|
||||
if not is_linear_op(node, include_quantization=True):
|
||||
continue
|
||||
|
||||
final_linear = node
|
||||
if not final_linear.args or not isinstance(final_linear.args[0], Node):
|
||||
continue
|
||||
|
||||
mul_node = final_linear.args[0]
|
||||
if not is_op(mul_node, torch.ops.aten.mul) or len(mul_node.args) < 2:
|
||||
continue
|
||||
|
||||
arg_a, arg_b = mul_node.args[:2]
|
||||
silu_node = (
|
||||
arg_a
|
||||
if is_op(arg_a, torch.ops.aten.silu)
|
||||
else arg_b
|
||||
if is_op(arg_b, torch.ops.aten.silu)
|
||||
else None
|
||||
)
|
||||
if silu_node is None:
|
||||
continue
|
||||
|
||||
if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)):
|
||||
continue
|
||||
linear_w1_node = silu_node.args[0]
|
||||
|
||||
# The other branch should be a linear op (w3 branch).
|
||||
linear_w3_node = arg_b if arg_a is silu_node else arg_a
|
||||
if not is_linear_op(linear_w3_node, include_quantization=True):
|
||||
continue
|
||||
if not (linear_w1_node.args and linear_w3_node.args):
|
||||
continue
|
||||
|
||||
# Extract parameters from each linear op.
|
||||
input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters(
|
||||
linear_w1_node
|
||||
)
|
||||
_, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node)
|
||||
_, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear)
|
||||
|
||||
if None in (weight_w1, weight_w3, weight_w2):
|
||||
continue
|
||||
|
||||
# Ensure the weight type is consistent across branches.
|
||||
if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2:
|
||||
continue
|
||||
weight_type = wt_type_w1
|
||||
|
||||
pattern_input_nodes.append(input_node_w1)
|
||||
pattern_output_nodes.append(final_linear)
|
||||
expert_weights["w1"].append(weight_w1)
|
||||
expert_weights["w3"].append(weight_w3)
|
||||
expert_weights["w2"].append(weight_w2)
|
||||
|
||||
# TODO: sanity check that all experts have same weight type
|
||||
if weight_type == "fp8":
|
||||
expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
|
||||
expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
|
||||
expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
|
||||
expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
|
||||
expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
|
||||
expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
|
||||
elif weight_type == "fp4":
|
||||
expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"])
|
||||
expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"])
|
||||
expert_scales["w1_alpha"].append(quant_params_w1["alpha"])
|
||||
expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"])
|
||||
expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"])
|
||||
expert_scales["w3_alpha"].append(quant_params_w3["alpha"])
|
||||
expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"])
|
||||
expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"])
|
||||
expert_scales["w2_alpha"].append(quant_params_w2["alpha"])
|
||||
|
||||
return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type
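For reference, the per-expert compute this matcher is looking for is the usual gated-MLP expert from the docstring, `(F.silu(x @ w1.t()) * (x @ w3.t())) @ w2.t()`. A minimal PyTorch sketch of that pattern (not code from this repository):

```python
import torch
import torch.nn.functional as F


def expert_forward(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
    gate = F.silu(F.linear(x, w1))   # w1 branch feeding the silu
    up = F.linear(x, w3)             # w3 branch
    return F.linear(gate * up, w2)   # w2 projects back to the hidden size
```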
|
||||
|
||||
|
||||
def _find_final_hidden_state_node(
|
||||
pattern_output_nodes: list[Node], end_boundary: Node
|
||||
) -> Optional[Node]:
|
||||
"""
|
||||
Identify the final hidden state node corresponding to the combine pattern:
|
||||
|
||||
(expert_output * routing_weight) → index_add_
|
||||
|
||||
For each expert output node (from the expert compute pattern), this function:
|
||||
1. Retrieves a multiplication node from its users.
|
||||
2. Extracts the second argument from the multiplication node (assumed to be the index node).
|
||||
3. Uses a BFS to locate the subsequent index_add_ node (guarded by the end_boundary).
|
||||
|
||||
After collecting all such index_add_ nodes, the final hidden state node is determined
|
||||
as the one that is not used by any of the other index_add_ nodes.
|
||||
|
||||
If any required attribute (users or args) is missing during the process or if no valid
|
||||
final node is found, the function returns None.
|
||||
"""
|
||||
|
||||
if not pattern_output_nodes:
|
||||
return None
|
||||
|
||||
index_add_nodes = []
|
||||
for node in pattern_output_nodes:
|
||||
if not node.users:
|
||||
return None
|
||||
mul_node = next(iter(node.users))
|
||||
if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
|
||||
return None
|
||||
index_node = mul_node.args[1]
|
||||
index_add_node = bfs(
|
||||
index_node, lambda n: is_op(n, torch.ops.aten.index_add_), boundary=end_boundary
|
||||
)
|
||||
if not index_add_node:
|
||||
return None
|
||||
index_add_nodes.append(index_add_node)
|
||||
|
||||
# The final node is defined as the index_add_node that is not used by any other index_add_nodes
|
||||
return next(
|
||||
(
|
||||
candidate
|
||||
for candidate in index_add_nodes
|
||||
if not any(
|
||||
candidate in other.args for other in index_add_nodes if candidate is not other
|
||||
)
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def _extract_index_branches_from_expert_outputs(
|
||||
pattern_output_nodes: list[Node],
|
||||
) -> tuple[list[Node], list[Node]]:
|
||||
"""
|
||||
Extract routing and experts branches from expert outputs.
|
||||
|
||||
For each expert output, find its multiplication user. From the
|
||||
multiplication node's second argument (an index node),
|
||||
extract:
|
||||
- The first argument as the routing branch.
|
||||
- The second argument (flattened if a list/tuple) as the experts branch.
|
||||
|
||||
Returns:
|
||||
A tuple (routing_branches, experts_branches).
|
||||
"""
|
||||
routing_branches, experts_branches = [], []
|
||||
for out in pattern_output_nodes:
|
||||
mul = next((u for u in out.users if is_op(u, torch.ops.aten.mul)), None)
|
||||
if not mul or len(mul.args) < 2:
|
||||
continue
|
||||
idx_node = mul.args[1]
|
||||
if not is_op(idx_node, torch.ops.aten.index):
|
||||
continue
|
||||
routing_branches.append(idx_node.args[0])
|
||||
experts = idx_node.args[1]
if isinstance(experts, (list, tuple)):
    experts_branches.extend(experts)
else:
    experts_branches.append(experts)
|
||||
return routing_branches, experts_branches
|
||||
|
||||
|
||||
def _remove_dead_inplace_nodes_in_region(
|
||||
graph: torch.fx.Graph,
|
||||
start_boundary: torch.fx.Node,
|
||||
end_boundary: torch.fx.Node,
|
||||
) -> bool:
|
||||
"""
|
||||
Searches (via BFS) for a dead in-place node (index_add_) in the region
|
||||
between start_boundary and end_boundary. If one is found, it is removed from the graph.
|
||||
Returns True if a node was removed, False otherwise.
|
||||
"""
|
||||
|
||||
def target(n: torch.fx.Node) -> bool:
|
||||
return is_op(n, {torch.ops.aten.index_add_}) and len(n.users) == 0
|
||||
|
||||
try:
|
||||
node_to_remove = bfs(start_boundary, target, attr_next="users", boundary=end_boundary)
|
||||
ad_logger.debug(f"Removing In-place Dead Node: {node_to_remove}")
|
||||
graph.erase_node(node_to_remove)
|
||||
return True
|
||||
except RuntimeError:
|
||||
return False
|
||||
@ -1,193 +0,0 @@
"""Graph transformation to automatically add kv cache into fused MHA op."""

import operator
from typing import Dict, Type

import torch
from torch.fx import Graph, GraphModule, Node

from ...custom_ops.attention_interface import AttentionDescriptor, CacheConfig
from ...distributed.common import all_gather_object, get_world_size
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.node_utils import get_all_input_output_nodes, is_op
from .._graph import add_graph_input, canonicalize_graph
def update_in_out_nodes(egm: GraphModule, cm: CachedSequenceInterface) -> None:
|
||||
"""Modify the graph module by adding new input nodes and canonicalizing the graph.
|
||||
|
||||
The new input nodes correspond to the extra arguments needed for cached and flattened attention.
|
||||
|
||||
Args:
|
||||
egm: The graph module to analyze and modify.
|
||||
cm: Cached sequence interface containing extra argument information.
|
||||
"""
|
||||
# loop through nodes to get input, output, and get_attr nodes
|
||||
input_nodes, output_nodes = get_all_input_output_nodes(egm.graph)
|
||||
|
||||
# we only expect one input node
|
||||
assert len(input_nodes) == 2, "Expected exactly two input nodes (input_ids, position_ids)."
|
||||
|
||||
# NOTE: for now, we wanna make sure we *only* return the final output and no hidden states.
|
||||
# Later on, we can revisit how to support returning hidden states.
|
||||
assert len(output_nodes) == 1, "Expected exactly one output node!"
|
||||
assert len(output_nodes[0].all_input_nodes) == 1, "Expected to only return final tensor output!"
|
||||
|
||||
ad_logger.info(f"Found {len(input_nodes)} input nodes and {len(output_nodes)} output nodes")
|
||||
|
||||
# Activate and add extra argument nodes
|
||||
new_args = cm.info.switch_to_cached_attn_inputs()
|
||||
for name in new_args:
|
||||
input_nodes.append(add_graph_input(egm, name))
|
||||
ad_logger.info(f"Added {len(new_args)} new input nodes for cached attention metadata")
|
||||
|
||||
canonicalize_graph(egm)
|
||||
|
||||
|
||||
def insert_cached_attention(
|
||||
egm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
attn_descriptor: Type[AttentionDescriptor],
|
||||
cache_config: CacheConfig,
|
||||
) -> None:
|
||||
"""Replace uncached source attention node with corresponding cached attn node."""
|
||||
# Get all attention nodes and their info objects
|
||||
source_op = attn_descriptor.get_source_attention_op()
|
||||
|
||||
# pick up graph
|
||||
graph: Graph = egm.graph
|
||||
|
||||
# look for relevant source attention nodes
|
||||
source_attn_nodes = [n for n in graph.nodes if is_op(n, source_op)]
|
||||
|
||||
if not source_attn_nodes:
|
||||
# If there are no nodes for kv cache insertion found, return current graph
|
||||
return
|
||||
|
||||
# Sanity check
|
||||
if cm.info.is_paged:
|
||||
assert attn_descriptor.is_paged(), "Paged sequence info requires paged attention op."
|
||||
|
||||
ad_logger.debug(f"Before inserting {attn_descriptor=} with cache: {egm}")
|
||||
|
||||
# retrieve input nodes
|
||||
input_nodes, _ = get_all_input_output_nodes(egm.graph)
|
||||
|
||||
# insert metadata computation and extract each argument as a node
|
||||
get_metadata, num_metadata = attn_descriptor.get_prepare_metadata_op()
|
||||
with graph.inserting_before(input_nodes[-1].next):
|
||||
ret_node = graph.call_function(
|
||||
get_metadata,
|
||||
args=(
|
||||
*input_nodes,
|
||||
cm.info.page_size,
|
||||
),
|
||||
)
|
||||
metadata_nodes = [
|
||||
graph.call_function(operator.getitem, args=(ret_node, idx))
|
||||
for idx in range(num_metadata)
|
||||
]
|
||||
|
||||
buffer_in_lookup: Dict[str, Node] = {}
|
||||
|
||||
# replace fused attention node with attention node that has kv cache
|
||||
num_cached_attn_replacements = 0
|
||||
for idx, attn_node in enumerate(source_attn_nodes):
|
||||
# pick out GEMMs
|
||||
qkv = attn_node.args[: attn_descriptor.get_num_qkv_args()]
|
||||
|
||||
# setup + store cache initializers and caches as input nodes
|
||||
cache_in_nodes = []
|
||||
for k, get_cache in attn_descriptor.get_cache_initializers(attn_node, cache_config).items():
|
||||
k_indexed = f"{k}_{idx}"
|
||||
cm.add_cache(k_indexed, get_cache)
|
||||
cache_in_nodes.append(add_graph_input(egm, k_indexed))
|
||||
|
||||
# setup + store global buffer initializers and buffers as input nodes
|
||||
# NOTE: we have to check against existing keys to make sure nothing is registered twice...
|
||||
buffer_in_nodes = []
|
||||
for k, get_buffer in attn_descriptor.get_global_buffer_initializers(attn_node).items():
|
||||
if k not in buffer_in_lookup:
|
||||
cm.add_cache(k, get_buffer)
|
||||
buffer_in_lookup[k] = add_graph_input(egm, k)
|
||||
buffer_in_nodes.append(buffer_in_lookup[k]) # store buffer nodes for this op
|
||||
|
||||
# retrieve constants for attention_op
|
||||
constants = attn_descriptor.get_constants(attn_node)
|
||||
|
||||
# insert cached attention replacement op
|
||||
with graph.inserting_before(attn_node):
|
||||
cached_attn_node = graph.call_function(
|
||||
attn_descriptor.get_cached_attention_op(),
|
||||
args=(*qkv, *metadata_nodes, *cache_in_nodes, *buffer_in_nodes, *constants),
|
||||
)
|
||||
attn_node.replace_all_uses_with(cached_attn_node)
|
||||
graph.erase_node(attn_node)
|
||||
num_cached_attn_replacements += 1
|
||||
|
||||
canonicalize_graph(egm)
|
||||
ad_logger.info(
|
||||
f"Replaced {num_cached_attn_replacements} {source_op} ops "
|
||||
f"with {attn_descriptor.get_cached_attention_op()}"
|
||||
)
|
||||
ad_logger.debug(f"After inserting {attn_descriptor=} with cache: {egm}")
|
||||
|
||||
|
||||
def resize_kv_cache(
|
||||
egm: GraphModule,
|
||||
cm: CachedSequenceInterface,
|
||||
free_mem_ratio: float = 0.8,
|
||||
) -> None:
|
||||
"""Inflate the kv cache to occupy the available GPU memory.
|
||||
|
||||
free_mem_ratio specifies the fraction of available memory to occupy.
|
||||
"""
|
||||
|
||||
def _get_mem_info_in_mb():
|
||||
free_mem, total_mem = torch.cuda.mem_get_info()
|
||||
return free_mem // 1024**2, total_mem // 1024**2
|
||||
|
||||
free_mem, total_mem = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
|
||||
current_cache_size = cm.current_cache_size_bytes()
|
||||
current_num_pages = cm.info.num_pages
|
||||
ad_logger.info(
|
||||
f"Current cache size: {current_cache_size}, Current num pages: {current_num_pages}"
|
||||
)
|
||||
|
||||
if free_mem_ratio == 0.0:
|
||||
ad_logger.info(f"Skipping cache resize for {free_mem_ratio=}")
|
||||
return
|
||||
|
||||
try:
|
||||
# Let's run a forward pass to get the memory usage
|
||||
cm.info._set_max_num_tokens_sample()
|
||||
free_mem_pre, _ = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
|
||||
|
||||
egm(*cm.args)
|
||||
|
||||
free_mem_post, _ = _get_mem_info_in_mb()
|
||||
ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")
|
||||
|
||||
memory_for_forward_pass = free_mem_pre - free_mem_post
|
||||
ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")
|
||||
|
||||
new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size
|
||||
new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
|
||||
|
||||
# Need to sync all the GPUs
|
||||
gathered_num_pages = [None] * get_world_size()
|
||||
all_gather_object(gathered_num_pages, new_num_pages)
|
||||
new_num_pages = min(gathered_num_pages)
|
||||
ad_logger.info(f"After all_gather - new_num_pages: {new_num_pages}")
|
||||
|
||||
cm.resize_cache(new_num_pages)
|
||||
except Exception as e:
|
||||
ad_logger.warning(
|
||||
f"Error encountered while resizing kv cache: {e}.\nSkipping cache resize."
|
||||
)
|
||||
|
||||
# Free memory
|
||||
torch.cuda.empty_cache()
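As a concrete illustration of the sizing math above (all numbers are made up): with `free_mem_ratio=0.8`, 20 GiB free after the sample forward pass, and a current cache of 4 GiB spread over 2048 pages, the new page count works out as follows.

```python
# Hypothetical numbers for the resize computation above.
free_mem_ratio = 0.8
free_mem_post_mb = 20 * 1024          # 20 GiB free after the forward pass, in MB
current_cache_size = 4 * 1024**3      # 4 GiB of KV cache currently allocated, in bytes
current_num_pages = 2048

bytes_per_page = current_cache_size // current_num_pages               # 2 MiB per page
new_cache_size = free_mem_post_mb * 1024 * 1024 * free_mem_ratio + current_cache_size
new_num_pages = int(new_cache_size // bytes_per_page)
print(new_num_pages)  # 10240 pages: the 4 GiB cache grows to roughly 20 GiB
```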
|
||||
@ -1,113 +0,0 @@
"""Graph transform to optimize RMSNorm execution using FlashInfer."""

from functools import partial

import torch
from torch.fx import GraphModule

from ...utils.logger import ad_logger

# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from .._graph import canonicalize_graph

_BACKEND_OPS = {
    "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
    "triton": torch.ops.auto_deploy.triton_rms_norm,
    "torch": torch.ops.auto_deploy.torch_rmsnorm,
}


def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Implements the RMSNorm pattern for pattern matching.

    Args:
        data: Input tensor to normalize.
        weight: Scaling weights for the normalized output.
        eps: Small constant for numerical stability.

    Returns:
        Normalized and scaled tensor.
    """
    input_dtype = data.dtype
    data = data.to(torch.float32)
    variance = data.pow(2).mean(-1, keepdim=True)
    data = data * torch.rsqrt(variance + eps)
    return weight * data.to(input_dtype)
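For a quick sanity check, the decomposed pattern can be compared against PyTorch's built-in RMSNorm, available as `torch.nn.functional.rms_norm` in PyTorch 2.4+. This comparison is only an illustration and is not part of the transform:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 512)
w = torch.randn(512)
eps = 1e-6

# Decomposed pattern, as matched by the transform above (fp32 inputs, so no casts needed).
variance = x.pow(2).mean(-1, keepdim=True)
decomposed = w * (x * torch.rsqrt(variance + eps))

# Built-in fused reference (PyTorch >= 2.4).
fused = F.rms_norm(x, (512,), weight=w, eps=eps)

print(torch.allclose(decomposed, fused, atol=1e-5))  # True
```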
|
||||
|
||||
|
||||
def _rms_norm_replacement(
|
||||
data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
|
||||
) -> torch.Tensor:
|
||||
"""Backend-specific rms_norm implementation.
|
||||
|
||||
Args:
|
||||
data: Input tensor to normalize.
|
||||
weight: Scaling weights for the normalized output.
|
||||
eps: Small constant for numerical stability.
|
||||
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
|
||||
|
||||
Returns:
|
||||
Normalized and scaled tensor using the specified backend implementation.
|
||||
"""
|
||||
|
||||
assert backend.lower() in _BACKEND_OPS, (
|
||||
f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
|
||||
)
|
||||
return _BACKEND_OPS[backend.lower()](data, weight, eps)
|
||||
|
||||
|
||||
def fuse_rmsnorm(gm: GraphModule, backend: str = "triton") -> None:
|
||||
"""Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
|
||||
|
||||
This function sets up pattern matching to identify RMSNorm operations in the graph
|
||||
and replaces them with optimized implementations. It uses dummy tensors to register
|
||||
the pattern matching rules.
|
||||
|
||||
Args:
|
||||
gm: Input graph module to transform.
|
||||
backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
|
||||
|
||||
Returns:
|
||||
Transformed graph module with optimized RMSNorm operations.
|
||||
"""
|
||||
if backend.lower() not in _BACKEND_OPS:
|
||||
raise ValueError(f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {backend}")
|
||||
ad_logger.info(f"Starting RMSNorm pattern matching with backend: {backend}")
|
||||
|
||||
graph = gm.graph
|
||||
patterns = ADPatternMatcherPass()
|
||||
|
||||
# Create dummy tensors for pattern matching
|
||||
bs = 2
|
||||
hidden_size = 512
|
||||
|
||||
def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
|
||||
return [
|
||||
torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
|
||||
torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
|
||||
eps,
|
||||
]
|
||||
|
||||
# Define configurations for different data types
|
||||
configs = [
|
||||
(torch.bfloat16, torch.bfloat16),
|
||||
(torch.float16, torch.float16),
|
||||
(torch.float32, torch.float32),
|
||||
]
|
||||
|
||||
# Register patterns for each configuration
|
||||
for input_dtype, weight_dtype in configs:
|
||||
register_ad_pattern(
|
||||
search_fn=_rms_norm_pattern,
|
||||
replace_fn=partial(_rms_norm_replacement, backend=backend),
|
||||
patterns=patterns,
|
||||
dummy_args=dummy_args(input_dtype, weight_dtype),
|
||||
op_ignore_types={},
|
||||
scalar_workaround={"eps": 1e-6},
|
||||
)
|
||||
|
||||
cnt = patterns.apply(graph)
|
||||
ad_logger.info(f"RMSNorm pattern count: {cnt}")
|
||||
canonicalize_graph(gm)
|
||||
ad_logger.debug("RMSNorm pattern matching completed.")
|
||||
@ -5,21 +5,11 @@ import gc
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..compile import compile_and_capture
|
||||
from ..custom_ops.attention_interface import AttentionRegistry
|
||||
from ..llm_args import AutoDeployConfig
|
||||
from ..models.factory import ModelFactory
|
||||
from ..shim.interface import CachedSequenceInterface
|
||||
from ..transform.optimizer import InferenceOptimizer as ModularInferenceOptimizer
|
||||
from ..utils.logger import ad_logger
|
||||
from .library import (
|
||||
fuse_allreduce_residual_rmsnorm,
|
||||
fuse_collectives,
|
||||
fuse_rmsnorm,
|
||||
insert_cached_attention,
|
||||
resize_kv_cache,
|
||||
update_in_out_nodes,
|
||||
)
|
||||
|
||||
|
||||
class InferenceOptimizer:
|
||||
@ -55,88 +45,60 @@ class InferenceOptimizer:
|
||||
self.ad_config.attn_backend
|
||||
).get_attention_layout()
|
||||
|
||||
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
|
||||
|
||||
# TODO (hg): similar to above.
|
||||
if "load_weights" in new_optimizer.config:
|
||||
new_optimizer.config[
|
||||
if "load_weights" in self.ad_config.transforms:
|
||||
self.ad_config.transforms[
|
||||
"load_weights"
|
||||
].checkpoint_device = self.ad_config.checkpoint_device
|
||||
new_optimizer.config["load_weights"].device = cm.device
|
||||
self.ad_config.transforms["load_weights"].device = cm.device
|
||||
|
||||
if "resize_kv_cache" in self.ad_config.transforms:
|
||||
self.ad_config.transforms[
|
||||
"resize_kv_cache"
|
||||
].free_mem_ratio = self.ad_config.free_mem_ratio
|
||||
if "insert_cached_attention" in self.ad_config.transforms:
|
||||
self.ad_config.transforms[
|
||||
"insert_cached_attention"
|
||||
].attn_backend = self.ad_config.attn_backend
|
||||
if "insert_cached_mla_attention" in self.ad_config.transforms:
|
||||
self.ad_config.transforms[
|
||||
"insert_cached_mla_attention"
|
||||
].attn_backend = self.ad_config.mla_backend
|
||||
|
||||
# TODO: (hg)Missing MLA here. Figure out how to add MLA since duplicate transforms are not allowed.
|
||||
# Old code:
|
||||
# detect attention op and replace with cache-aware op
|
||||
# for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
|
||||
# attn_descriptor = AttentionRegistry.get(a_backend)
|
||||
# insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
|
||||
|
||||
if "compile_model" in self.ad_config.transforms:
|
||||
self.ad_config.transforms[
|
||||
"compile_model"
|
||||
].cuda_graph_batch_sizes = self.ad_config.cuda_graph_batch_sizes
|
||||
self.ad_config.transforms[
|
||||
"compile_model"
|
||||
].compile_backend = self.ad_config.compile_backend
|
||||
|
||||
new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
|
||||
# TODO: (hg) move this. let match_rope_layout and match_atten_layout use this shared config
|
||||
new_optimizer.shared_config.attn_backend = self.ad_config.attn_backend
|
||||
|
||||
egm = new_optimizer(cm)
|
||||
|
||||
# TODO (lucaslie): continue moving legacy transforms to the new optimizer
|
||||
############################################################################################
|
||||
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
|
||||
############################################################################################
|
||||
# NOTE: (hg)Disabled visualization since compiled gm is a CapturedGraph instead of GraphModule.
|
||||
# We can add a new stage in the optimizer to visualize the intermediate gm.
|
||||
# if self.ad_config.visualize:
|
||||
# try:
|
||||
# from .library import visualize_namespace
|
||||
|
||||
# run MoE fusion
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
|
||||
# fuse_moe(egm)
|
||||
|
||||
# run GEMM fusion
|
||||
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
|
||||
# fuse_gemms(egm)
|
||||
|
||||
# check if we can fuse allreduce, residual and rmsnorm
|
||||
fuse_allreduce_residual_rmsnorm(egm)
|
||||
|
||||
# check if we can fuse collectives
|
||||
fuse_collectives(egm)
|
||||
|
||||
# TODO (lucaslie): add backend selection as part of configurable inference optimizers
|
||||
# check if we can fuse rmsnorm
|
||||
fuse_rmsnorm(egm, "flashinfer")
|
||||
|
||||
# visualize the final graph
|
||||
if self.ad_config.visualize:
|
||||
try:
|
||||
from .library import visualize_namespace
|
||||
|
||||
visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
|
||||
ad_logger.warning(
|
||||
"Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
|
||||
" the graph."
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
############################################################################################
|
||||
# SWITCH TO CACHED+FLATTENED ATTENTION + INITIALIZE CACHES
|
||||
############################################################################################
|
||||
|
||||
update_in_out_nodes(egm, cm)
|
||||
|
||||
# detect attention op and replace with cache-aware op
|
||||
for a_backend in [self.ad_config.attn_backend, self.ad_config.mla_backend]:
|
||||
attn_descriptor = AttentionRegistry.get(a_backend)
|
||||
insert_cached_attention(egm, cm, attn_descriptor, self.factory.get_cache_config())
|
||||
|
||||
# initialize cache on correct device
|
||||
cm.initialize_caches()
|
||||
|
||||
# resize kv cache to occupy the available GPU memory up to free_mem_ratio
|
||||
resize_kv_cache(egm, cm, free_mem_ratio=self.ad_config.free_mem_ratio)
|
||||
|
||||
############################################################################################
|
||||
# COMPILE MODEL
|
||||
############################################################################################
|
||||
|
||||
cm.info.set_generate_only_batch()
|
||||
compiler_kwargs = {
|
||||
"cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes,
|
||||
"num_batched_inputs": 2, # TODO (lucaslie): improve once we have a config system...
|
||||
}
|
||||
egm_compiled = compile_and_capture(
|
||||
egm,
|
||||
self.ad_config.compile_backend,
|
||||
args=cm.args,
|
||||
dynamic_shapes=cm.dynamic_shapes,
|
||||
compiler_kwargs=compiler_kwargs,
|
||||
)
|
||||
cm.info.reset()
|
||||
# visualize_namespace(egm, args=cm.args, dynamic_shapes=cm.dynamic_shapes)
|
||||
# ad_logger.warning(
|
||||
# "Please run `pip install -r examples/auto_deploy/requirements.txt` to visualize"
|
||||
# " the graph."
|
||||
# )
|
||||
# except ImportError:
|
||||
# pass
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
return egm_compiled
|
||||
return egm
|
||||
|
||||
@ -43,11 +43,13 @@ def _patch_unsupported_input_tensor():
|
||||
"""
|
||||
original_fn = lowering.unsupported_input_tensor
|
||||
|
||||
def patched_fn(t: torch.Tensor, node=None):
|
||||
def patched_fn(t: torch.Tensor, *args, **kwargs):
|
||||
"""Bypass meta tensor check."""
|
||||
if t.is_meta:
|
||||
return False
|
||||
return original_fn(t, node)
|
||||
return original_fn(
|
||||
t, *args, **kwargs
|
||||
) # a generic pass-through of the arguments to accommodate torch side change
|
||||
|
||||
lowering.unsupported_input_tensor = patched_fn
|
||||
try:
|
||||
|
||||
@ -453,7 +453,8 @@ class AutoTuner:
|
||||
p.name
|
||||
for p in inspect.signature(runner.forward).parameters.values()
|
||||
}
|
||||
valid_tactics = runner.get_valid_tactics(input_tensors, profile)
|
||||
valid_tactics = runner.get_valid_tactics(input_tensors, profile,
|
||||
**kwargs)
|
||||
if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
|
||||
runner(
|
||||
input_tensors,
|
||||
|
||||
@ -531,3 +531,11 @@ def _register_fake():
|
||||
return router_logits.new_empty(
|
||||
sz, dtype=torch.int32), router_logits.new_empty(sz,
|
||||
dtype=torch.float32)
|
||||
|
||||
@torch.library.register_fake("trtllm::default_moe_routing_op")
|
||||
def _(router_logits, topk):
|
||||
num_tokens = router_logits.shape[0]
|
||||
sz = (num_tokens, topk)
|
||||
return router_logits.new_empty(
|
||||
sz, dtype=torch.int32), router_logits.new_empty(sz,
|
||||
dtype=torch.float32)
|
||||
|
||||
@ -81,12 +81,9 @@ class MoERunner(TunableRunner):
|
||||
use_fused_finalize)
|
||||
self.fused_moe_runner = MoERunner.runner_dict[instance_key]
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
return range(self.fused_moe_runner.get_tactic_num())
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
return range(self.fused_moe_runner.get_tactic_num(kwargs["gemm_idx"]))
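The signature change above threads tuner-provided keyword arguments (such as `gemm_idx`) through to the runner. A minimal sketch of a runner that follows the new `get_valid_tactics(..., **kwargs)` contract; the class and attribute names here are illustrative, not the repository's:

```python
from typing import List

import torch


class ToyGemmRunner:
    """Illustrative runner following the new get_valid_tactics(**kwargs) contract."""

    def __init__(self, num_tactics_per_gemm: List[int]):
        self._num_tactics_per_gemm = num_tactics_per_gemm

    def get_valid_tactics(self, inputs: List[torch.Tensor], profile, **kwargs) -> List[int]:
        # Extra keyword arguments are forwarded by the AutoTuner and can be used
        # to narrow the tactic search space per GEMM.
        gemm_idx = kwargs.get("gemm_idx", 0)
        return list(range(self._num_tactics_per_gemm[gemm_idx]))

    def forward(self, inputs: List[torch.Tensor], tactic: int = -1, **kwargs) -> torch.Tensor:
        # A real runner would dispatch to a tuned kernel; this toy just matmuls.
        a, b = inputs[:2]
        return a @ b
```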
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -318,11 +315,8 @@ class FP8RowwiseGemmRunner(TunableRunner):
|
||||
self.fp8_rowwise_gemm_runner = FP8RowwiseGemmRunner.runner_dict[
|
||||
instance_key]
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
return list(range(self.fp8_rowwise_gemm_runner.get_num_configs()))
|
||||
|
||||
def forward(
|
||||
@ -403,11 +397,8 @@ class FP4GemmRunner(TunableRunner):
|
||||
output_dtype, int(fp4_gemm_type))
|
||||
self.fp4_gemm_runner = FP4GemmRunner.runner_dict[instance_key]
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
return list(range(self.fp4_gemm_runner.get_num_configs()))
|
||||
|
||||
def forward(
|
||||
@ -518,11 +509,8 @@ class FP8BatchedGemmRunner(TunableRunner):
|
||||
|
||||
return out_tensors
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
|
||||
mat1, mat2, _, _, _ = inputs
|
||||
|
||||
@ -735,11 +723,8 @@ class WeightOnlyQuantGemmRunner(TunableRunner):
|
||||
self.weight_only_quant_gemm_runner = WeightOnlyQuantGemmRunner.runner_dict[
|
||||
instance_key]
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
return list(range(self.weight_only_quant_gemm_runner.get_num_configs()))
|
||||
|
||||
def forward(
|
||||
@ -813,11 +798,8 @@ class FinegrainedMixedDtypeGemm(TunableRunner):
|
||||
self._finegrained_mixed_dtype_gemm_runner = FinegrainedMixedDtypeGemm._runner_dict[
|
||||
instance_key]
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
return list(
|
||||
range(self._finegrained_mixed_dtype_gemm_runner.get_num_configs()))
|
||||
|
||||
|
||||
@ -122,11 +122,8 @@ class FP4BlockScaleMoERunner(TunableRunner):
|
||||
self.local_num_experts, self.routed_scaling_factor,
|
||||
self.routing_method_type, self.do_finalize, tactic)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
|
||||
args = FP4BlockScaleMoEInputs(*inputs)
|
||||
|
||||
@ -409,11 +406,8 @@ class FP8BlockScaleMoERunner(TunableRunner):
|
||||
self.local_expert_offset, self.local_num_experts,
|
||||
self.routed_scaling_factor, self.routing_method_type, tactic)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
|
||||
args = FP8BlockScaleMoEInputs(*inputs)
|
||||
|
||||
@ -670,11 +664,8 @@ class MxE4m3MxE2m1BlockScaleMoERunner(TunableRunner):
|
||||
self.local_expert_offset, self.local_num_experts,
|
||||
self.routed_scaling_factor, self.routing_method_type, tactic)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
|
||||
args = MxE4m3MxE2m1BlockScaleMoEInputs(*inputs)
|
||||
|
||||
@ -907,11 +898,8 @@ class E4m3MxE2m1BlockScaleMoERunner(TunableRunner):
|
||||
self.local_expert_offset, self.local_num_experts,
|
||||
self.routed_scaling_factor, self.routing_method_type, tactic)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
|
||||
args = E4m3MxE2m1BlockScaleMoEInputs(*inputs)
|
||||
|
||||
@ -1123,11 +1111,8 @@ class Bf16MxE2m1BlockScaleMoERunner(TunableRunner):
|
||||
self.local_num_experts, self.routed_scaling_factor,
|
||||
self.routing_method_type, tactic)
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
) -> List[int]:
|
||||
def get_valid_tactics(self, inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile, **kwargs) -> List[int]:
|
||||
|
||||
args = Bf16MxE2m1BlockScaleMoEInputs(*inputs)
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
|
||||
from ..modules.multi_stream_utils import maybe_execute_in_parallel
|
||||
from ..modules.rms_norm import RMSNorm
|
||||
from ..peft.lora.layer import LoraLayer
|
||||
from ..speculative import MTPSpecMetadata, SpecMetadata
|
||||
from ..speculative import SpecMetadata
|
||||
from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor
|
||||
from .modeling_speculative import SpecDecOneEngineForCausalLM
|
||||
from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights,
|
||||
@ -230,7 +230,7 @@ class DeepseekV3Attention(MLA):
|
||||
aux_stream: Optional[torch.cuda.Stream] = None,
|
||||
):
|
||||
config = model_config.pretrained_config
|
||||
predicted_tokens_per_seq = model_config.spec_config.num_nextn_predict_layers + 1 if model_config.spec_config is not None else 1
|
||||
predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
|
||||
super().__init__(hidden_size=config.hidden_size,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
@ -750,6 +750,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: torch.Tensor,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
**kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
@ -765,16 +766,24 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
**kwargs,
|
||||
)
|
||||
if isinstance(self.mlp, Deepseekv3MoE):
|
||||
if spec_metadata is not None and spec_metadata.is_layer_capture(
|
||||
self.layer_idx):
|
||||
self.fusion_config.POST_MOE_FUSION = False
|
||||
return self.forward_MoE(
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
spec_metadata=spec_metadata,
|
||||
)
|
||||
else:
|
||||
if spec_metadata is not None and spec_metadata.is_layer_capture(
|
||||
self.layer_idx):
|
||||
self.fusion_config.POST_MLP_FUSION = False
|
||||
assert isinstance(self.mlp, GatedMLP)
|
||||
return self.forward_mlp(
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
spec_metadata=spec_metadata,
|
||||
)
|
||||
|
||||
def forward_MoE(
|
||||
@ -782,6 +791,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: torch.Tensor,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
def _run_MoE(hidden_states, hidden_states_fp4, do_finalize):
|
||||
@ -856,6 +866,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
hidden_states, residual = self.moe_allreduce(
|
||||
fc2_output, all_reduce_params=moe_all_reduce_params)
|
||||
else:
|
||||
if spec_metadata is not None and spec_metadata.is_layer_capture(
|
||||
self.layer_idx):
|
||||
spec_metadata.maybe_capture_hidden_states(
|
||||
self.layer_idx, hidden_states, residual)
|
||||
if self.next_layer_layernorm is not None:
|
||||
hidden_states, residual = self.next_layer_layernorm(
|
||||
hidden_states, residual)
|
||||
@ -866,6 +880,7 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
if self.fusion_config.PRE_MLP_FUSION:
|
||||
@ -903,6 +918,10 @@ class DeepseekV3DecoderLayer(DecoderLayer):
|
||||
),
|
||||
)
|
||||
else:
|
||||
if spec_metadata is not None and spec_metadata.is_layer_capture(
|
||||
self.layer_idx):
|
||||
spec_metadata.maybe_capture_hidden_states(
|
||||
self.layer_idx, hidden_states, residual)
|
||||
if self.next_layer_layernorm is not None:
|
||||
hidden_states, residual = self.next_layer_layernorm(
|
||||
hidden_states, residual)
|
||||
@ -1105,6 +1124,7 @@ class DeepseekV3Model(DecoderModel):
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
residual=residual,
|
||||
spec_metadata=spec_metadata,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
@ -1132,7 +1152,8 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
|
||||
model_config=model_config)
|
||||
|
||||
self.model_nextn = 0
|
||||
if model_config.spec_config is not None:
|
||||
if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp(
|
||||
):
|
||||
model_nextn = model_config.spec_config.num_nextn_predict_layers
|
||||
ckpt_nextn = self.config.num_nextn_predict_layers
|
||||
self.num_hidden_layers = self.config.num_hidden_layers
|
||||
@ -1167,11 +1188,10 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
|
||||
input_ids: torch.IntTensor = None,
|
||||
position_ids: Optional[torch.IntTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
spec_metadata: Optional[MTPSpecMetadata] = None,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
return_context_logits: bool = False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
attn_metadata.num_generations_per_batch = self.model_nextn + 1
|
||||
return super().forward(attn_metadata=attn_metadata,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
@ -1313,7 +1333,9 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
|
||||
|
||||
for name, module in tqdm(all_named_modules.items(),
|
||||
desc="Loading weights"):
|
||||
if len(module._parameters) > 0:
|
||||
if len(module._parameters) <= 0 or name.startswith("draft_model"):
|
||||
continue
|
||||
else:
|
||||
names = name.split('.')
|
||||
parent_module_name = '.'.join(names[:-1])
|
||||
if "model.layers" in name and int(
|
||||
|
||||
@ -194,11 +194,16 @@ class Gemma3VLM(PreTrainedModel):
|
||||
"text_config", "vision_config"
|
||||
], f"Expected subconfig name to be either 'text_config' or 'vision_config'. Got {name} instead."
|
||||
pretrained_config = getattr(model_config.pretrained_config, name)
|
||||
# ModelOpt currently doesn't quantize the vision part. Without setting quant config to None,
|
||||
# weight loading fails for vision.
|
||||
quant_config = model_config.quant_config if name == "text_config" else None
|
||||
# FlashInfer backend supports custom mask which is needed for bidirectional mask in decoder.
|
||||
preferred_backend = "FLASHINFER" if name == "text_config" else "TRTLLM"
|
||||
sub_model_config: ModelConfig[Gemma3Config] = dataclasses.replace(
|
||||
model_config,
|
||||
pretrained_config=pretrained_config,
|
||||
attn_backend=preferred_backend)
|
||||
attn_backend=preferred_backend,
|
||||
quant_config=quant_config)
|
||||
# Make sure some fields that are not explicitly included in the sub config, but present
|
||||
# in the top-level config, are replicated.
|
||||
if (hasattr(sub_model_config.pretrained_config, "torch_dtype")
|
||||
|
||||
@ -221,7 +221,9 @@ class NemotronHModel(DecoderModel):
|
||||
)
|
||||
|
||||
if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
|
||||
self.mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests)
|
||||
self.mamba_metadata = Mamba2Metadata(
|
||||
attn_metadata.max_num_requests,
|
||||
chunk_size=self.model_config.pretrained_config.chunk_size)
|
||||
self.mamba_metadata.prepare(attn_metadata)
|
||||
|
||||
if inputs_embeds is None:
|
||||
|
||||
@ -611,23 +611,21 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
|
||||
@staticmethod
|
||||
def lora_config(model_dir: str):
|
||||
_lora_config = LoraConfig(
|
||||
lora_dir=[
|
||||
f"{model_dir}/vision-lora",
|
||||
f"{model_dir}/speech-lora",
|
||||
],
|
||||
lora_target_modules=[
|
||||
"attn_qkv",
|
||||
"attn_dense",
|
||||
"mlp_h_to_4h",
|
||||
"mlp_gate_up",
|
||||
"mlp_4h_to_h",
|
||||
],
|
||||
trtllm_modules_to_hf_modules={
|
||||
"attn_qkv": "qkv_proj",
|
||||
"attn_dense": "o_proj",
|
||||
"mlp_h_to_4h": "gate_up_proj",
|
||||
"mlp_gate_up": "gate_up_proj",
|
||||
"mlp_4h_to_h": "down_proj",
|
||||
},
|
||||
max_lora_rank=320, # Max rank for Phi4MM.
|
||||
swap_gate_up_proj_lora_b_weight=
|
||||
False, # Disable swap gate_up_proj.lora_B.weight for Phi4MM.
|
||||
)
|
||||
return _lora_config
|
||||
|
||||
|
||||
@ -155,10 +155,12 @@ class Eagle3DraftModel(DecoderModel):
|
||||
else:
|
||||
self.hidden_size_in = config.hidden_size
|
||||
|
||||
self.fc = Linear(self.hidden_size_in * 3,
|
||||
config.hidden_size,
|
||||
bias=getattr(config, "bias", False),
|
||||
dtype=config.torch_dtype)
|
||||
if self.spec_config.num_capture_layers > 1:
|
||||
self.fc = Linear(self.hidden_size_in *
|
||||
self.spec_config.num_capture_layers,
|
||||
config.hidden_size,
|
||||
bias=getattr(config, "bias", False),
|
||||
dtype=config.torch_dtype)
|
||||
|
||||
self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)
|
||||
|
||||
|
||||
@ -183,18 +183,28 @@ class BaseMoeRoutingMethod(nn.Module):

class DefaultMoeRoutingMethod(BaseMoeRoutingMethod):

    def __init__(self, top_k: int):
    def __init__(self, top_k: int, force_enable_pytorch_op: bool = False):
        super().__init__()
        self.top_k = top_k
        self.force_enable_pytorch_op = force_enable_pytorch_op

    def apply(self,
              router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
    def apply_pytorch(
            self, router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        topk_values, topk_indices = torch.topk(torch.nn.functional.softmax(
            router_logits.float(), dim=-1),
                                               k=self.top_k,
                                               dim=-1)
        return topk_indices.to(torch.int32), topk_values

    def apply(self,
              router_logits: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        num_experts = router_logits.shape[-1]
        if self.force_enable_pytorch_op or num_experts > 128 or self.top_k > 8:
            return self.apply_pytorch(router_logits)
        else:
            return torch.ops.trtllm.default_moe_routing_op(
                router_logits, self.top_k)

    @property
    def routing_method_type(self):
        return RoutingMethodType.Default

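The fused `torch.ops.trtllm.default_moe_routing_op` is only taken for at most 128 experts and `top_k <= 8` (and when `force_enable_pytorch_op` is off); otherwise the PyTorch path above runs. A self-contained sketch of that fallback, handy for checking output shapes and dtypes:

```python
# Reference routing, equivalent to DefaultMoeRoutingMethod.apply_pytorch above:
# softmax over experts, then top-k probabilities and int32 expert indices.
import torch


def default_routing_pytorch(router_logits: torch.Tensor, top_k: int):
    probs = torch.nn.functional.softmax(router_logits.float(), dim=-1)
    topk_values, topk_indices = torch.topk(probs, k=top_k, dim=-1)
    return topk_indices.to(torch.int32), topk_values


logits = torch.randn(4, 256)  # 4 tokens, 256 experts -> the fused op is skipped
indices, values = default_routing_pytorch(logits, top_k=8)
print(indices.shape, values.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```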
@ -13,15 +13,83 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Tuple

import torch

from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata


def cu_seqlens_to_chunk_indices_offsets(
        cu_seqlens: torch.Tensor,
        chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        cu_seqlens (torch.Tensor): 1D tensor of cumulative sequence lengths, shape (num_seqs + 1,). The first element should be 0. Each entry represents the starting index of a sequence in the flattened token array.
        chunk_size (int): The size of each physical mamba chunk (number of tokens per chunk).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
        - chunk_indices (torch.Tensor): 1D tensor of indices indicating the physical chunk for each logical chunk.
        - chunk_offsets (torch.Tensor): 1D tensor of offsets indicating the starting index of each logical chunk within its physical chunk.

    This function computes the chunk indices and offsets for the given cu_seqlens and chunk_size.
    Both are tensors of integers with length N, where N is the number of logical (pseudo) chunks.
    A logical chunk is a sequence of tokens that are all part of the same sequence and are all in the same physical mamba chunk.
    In other words, a logical chunk changes every time we cross a sequence boundary or a physical mamba chunk boundary.
    Logical chunks are needed to handle batched requests with initial states (see _state_passing_fwd and _chunk_scan_fwd).
    The chunk_indices tensor contains the index of the physical chunk for each logical chunk.
    The chunk_offsets tensor contains the offset (AKA starting index) of the logical chunk in the physical chunk.

    Example:
        cu_seqlens = [0, 5, 10]
        chunk_size = 8
        -> chunk_indices = [0, 0, 1]
        -> chunk_offsets = [0, 5, 0]

    In this example, we have 2 sequences, each with 5 tokens. The physical chunk size is 8 tokens.
    We have three logical chunks:
    - the first logical chunk starts at token 0 in the first physical chunk and contains all 5 tokens from the first sequence
    - the second logical chunk starts at token 5 in the first physical chunk and contains the first 3 tokens from the second sequence
    - the third logical chunk starts at token 0 in the second physical chunk and contains the remaining 2 tokens from the second sequence
    """

    total_seqlens = cu_seqlens[-1]
    cu_seqlens = cu_seqlens[1:]  # remove prepended 0

    # outputs will have length expansion of chunks that do not divide
    # chunk_size
    N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
                                                 > 0).sum()
    chunk_indices = torch.arange(N, dtype=torch.int, device=cu_seqlens.device)
    chunk_offsets = torch.zeros((N, ),
                                dtype=torch.int,
                                device=cu_seqlens.device)

    p = 0  # num of insertions
    for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):

        # if does not divide chunk_size, then there is one chunk insertion
        p += (s % chunk_size > 0)

        # get the dimensions
        # - the + 1 for _e is to shift the boundary by one chunk
        # - this shifting is not needed if chunk_size divides e
        _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)

        # adjust indices and offsets
        chunk_indices[_s:_e] -= p
        chunk_offsets[_s] = s % chunk_size

    return chunk_indices, chunk_offsets

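Working through the docstring example by hand (or calling the function above on CPU) gives three logical chunks: the first two live in physical chunk 0, the last one in physical chunk 1.

```python
# Quick check of the docstring example: two sequences of 5 tokens, chunk_size=8.
import torch

cu_seqlens = torch.tensor([0, 5, 10], dtype=torch.int32)
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(cu_seqlens,
                                                                   chunk_size=8)
print(chunk_indices.tolist())  # [0, 0, 1]
print(chunk_offsets.tolist())  # [0, 5, 0]
```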
class Mamba2Metadata:

    def __init__(self, max_batch_size: int):
    def __init__(self, max_batch_size: int, chunk_size: int):
        self.max_batch_size = max_batch_size
        self.chunk_size = chunk_size

        # cumulative sequence lengths for prefill requests [batch_size+1]
        self.cu_seqlens = torch.zeros(max_batch_size + 1,
@ -31,9 +99,18 @@ class Mamba2Metadata:
        # sequence index for prefill requests [num_prefill_tokens] - specifies which request each token belongs to
        self.seq_idx: torch.Tensor = None

        # helper tensors for chunked prefill
        self.has_initial_states = torch.zeros(max_batch_size,
                                              dtype=torch.bool,
                                              device="cuda")
        self.use_initial_states = False
        self.chunk_indices: torch.Tensor = None
        self.chunk_offsets: torch.Tensor = None

    def prepare(self, attn_metadata: AttentionMetadata):
        num_contexts = attn_metadata.num_contexts
        context_lens = attn_metadata.seq_lens_cuda[:num_contexts]
        num_ctx_tokens = attn_metadata.num_ctx_tokens
        if num_contexts > 0:
            torch.cumsum(context_lens,
                         dim=0,
@ -44,4 +121,17 @@ class Mamba2Metadata:
                                          dtype=torch.int,
                                          device=self.cu_seqlens.device),
                repeats=context_lens,
                output_size=self.cu_seqlens[num_contexts]).unsqueeze(0)
                output_size=num_ctx_tokens).unsqueeze(0)

            num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq
            self.has_initial_states[:num_contexts] = torch.tensor(
                num_cached_tokens_per_seq[:num_contexts]) > 0
            # precomputed bool to avoid host<->device syncs during forward pass
            self.use_initial_states = torch.any(
                self.has_initial_states[:num_contexts]).item()
            if self.use_initial_states:
                self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
                    self.cu_seqlens[:num_contexts + 1], self.chunk_size)
            else:
                self.chunk_indices = None
                self.chunk_offsets = None
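The new bookkeeping in `prepare()` boils down to: a prefill request has an initial SSM state iff some of its tokens are already cached, and the chunk index/offset tensors are only built when at least one such request exists. A minimal sketch with made-up cache counts:

```python
# Sketch of the initial-state detection in Mamba2Metadata.prepare(), CPU-only and
# with hypothetical cache counts for three prefill requests.
import torch

num_cached_tokens_per_seq = [0, 64, 0]  # request 1 resumes a chunked prefill
has_initial_states = torch.tensor(num_cached_tokens_per_seq) > 0
use_initial_states = bool(torch.any(has_initial_states))  # precomputed on the host

print(has_initial_states.tolist(), use_initial_states)  # [False, True, False] True
# Only when use_initial_states is True are chunk_indices / chunk_offsets built
# via cu_seqlens_to_chunk_indices_offsets(...).
```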
@ -191,12 +191,15 @@ class Mamba2Mixer(nn.Module):

            cu_seqlens = mamba_metadata.cu_seqlens[:num_prefills + 1]
            seq_idx = mamba_metadata.seq_idx
            has_initial_states = mamba_metadata.has_initial_states[:
                                                                   num_prefills]

            xbc_p = causal_conv1d_fn(xbc_p.transpose(0, 1),
                                     self.conv1d.weight,
                                     self.conv1d.bias,
                                     activation="silu",
                                     conv_states=conv_states,
                                     has_initial_state=has_initial_states,
                                     query_start_loc=cu_seqlens,
                                     cache_indices=state_indices_p).transpose(
                                         0, 1)
@ -216,6 +219,12 @@ class Mamba2Mixer(nn.Module):
                           "b l (h p) -> b l h p",
                           h=self.tp_nheads)

            initial_states = None
            if mamba_metadata.use_initial_states:
                initial_states = torch.where(
                    has_initial_states[:, None, None, None],
                    ssm_states[state_indices_p], 0)

            y, current_ssm_states = mamba_chunk_scan_combined(
                x_p,
                dt_p,
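The `torch.where` above gathers a per-request initial state from the SSM cache, zeroing it for requests that start from scratch. A toy-shape sketch of that gather (cache pool size and slot indices are hypothetical):

```python
# Toy-shape sketch of building initial_states for the chunked scan: requests with
# cached context reuse their cache slot, the others start from zeros.
import torch

nheads, headdim, dstate = 2, 4, 8
ssm_states = torch.randn(16, nheads, headdim, dstate)  # hypothetical cache pool
state_indices_p = torch.tensor([3, 7])                 # cache slots of 2 prefill requests
has_initial_states = torch.tensor([False, True])

initial_states = torch.where(has_initial_states[:, None, None, None],
                             ssm_states[state_indices_p], 0)
print(initial_states[0].abs().max().item())       # 0.0 -> request 0 starts from scratch
print(torch.equal(initial_states[1], ssm_states[7]))  # True -> request 1 resumes
```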
@ -226,7 +235,9 @@ class Mamba2Mixer(nn.Module):
                D=self.D,
                z=z_p,
                dt_bias=self.dt_bias,
                initial_states=None,
                initial_states=initial_states,
                chunk_indices=mamba_metadata.chunk_indices,
                chunk_offsets=mamba_metadata.chunk_offsets,
                dt_softplus=self.delta_softplus,
                cu_seqlens=cu_seqlens,
                seq_idx=seq_idx,

@ -314,11 +314,12 @@ def _chunk_scan_fwd_kernel(

            # get the cs at the offset boundary
            # - c_off == 0 is a passthrough
            # - We need dA_cs at the boundary, defined by c_off - no need
            #   to increase pointer by pid_m (it is a constant offset,
            #   i.e. the same for all blocks)
            dA_cs_m_boundary = tl.load(
                dA_cumsum_ptr +
                (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
                mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
                      and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
                dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
                mask=(((c_off - 1) > -1) and (c_off < chunk_size)),
                other=0.0).to(tl.float32)

            if HAS_SEQ_IDX:
@ -110,21 +110,24 @@ def _mamba_chunk_scan_combined_fwd(
    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    # - for handling chunked prefill, this requires i) initial_states
    #   ii) seq_idx and iii) is_cont_batched to be all specified.
    #   ii) seq_idx, iii) is_cont_batched and iv) chunk_offsets to all be specified.
    # - When a new seq_idx is detected, we will stop passing the prev_state
    #   and switch accordingly to the init_state corresponding to the new seq_idx.
    # - We will also make sure that the dA_cumsum is taken only from the start of the
    #   sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
    # - this will ensure that states will be updated with the rightmost flushed seq_idx
    #   of the previous chunk. This implies that the first chunk of states is either 0
    #   or equal to init_states of the first example.
    states, final_states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
        dA_cumsum,
        initial_states=(rearrange(initial_states, "... p n -> ... (p n)")
                        if initial_states is not None else None),
        seq_idx=seq_idx,
        chunk_size=chunk_size,
        out_dtype=mamba_ssm_cache_dtype or C.dtype,
        is_cont_batched=cu_seqlens is not None)
        is_cont_batched=cu_seqlens is not None,
        chunk_offsets=chunk_offsets)
    states, final_states = [
        rearrange(t, "... (p n) -> ... p n", n=dstate)
        for t in [states, final_states]

@ -41,6 +41,8 @@ def _state_passing_fwd_kernel(
    dA_cs_ptr,
    initstates_ptr,
    seq_idx_ptr,
    chunk_offsets_ptr,
    chunk_meta_num,
    # Matrix dimensions
    dim,
    nchunks,
@ -61,6 +63,7 @@ def _state_passing_fwd_kernel(
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_initstates_batch,
    stride_initstates_head,
    stride_initstates_dim,
@ -76,7 +79,8 @@ def _state_passing_fwd_kernel(
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + (
        chunk_size - 1) * stride_dA_cs_csize
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    final_states_ptr += (pid_b * stride_final_states_batch +
                         pid_h * stride_final_states_head)
@ -105,35 +109,63 @@ def _state_passing_fwd_kernel(
                         other=0.0).to(tl.float32)
    tl.store(out_ptrs, states, mask=offs_m < dim)
    out_ptrs += stride_out_chunk
    seq_idx = 0
    prev_seq_idx_chunk_end = 0
    logical_chunk_idx = 0
    for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim,
                             other=0.0).to(tl.float32)
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        scale_mask = True
        if HAS_SEQ_IDX:
            # - the seq to pass forward is the one that is flushed to the right
            #   boundary.
            # - that is given by seq_idx_new below.
            seq_idx_new = tl.load(seq_idx_ptr +
                                  (min((c + 1) * chunk_size, seqlen) - 1) *
                                  stride_seq_idx_seqlen)
            # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
            seq_idx_chunk_end = tl.load(seq_idx_ptr + (min(
                (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen)
            if HAS_INITSTATES:
                if IS_CONT_BATCHED and seq_idx != seq_idx_new:
                if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
                    # this means in the current chunk the rightmost flushed seq
                    # has changed.
                    # - so we do not propagate the state from previous chunk
                    # - but rather we load that sequence's init state
                    initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
                    initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch

                    # - update state with seq_idx_new's init state
                    states = tl.load(initstates_ptrs,
                                     mask=offs_m < dim,
                                     other=0.0).to(tl.float32)
            else:
                scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)

            seq_idx = seq_idx_new
                    # - we need to consider the cumsum only of the last sequence in the chunk
                    # - find its starting position (given by c_off of the logical chunk index)
                    # - and subtract the cumsum just before that position from the total cumsum
                    # - first, update the logical chunk index (add the number of sequences in the current physical chunk):
                    #   sequence index at the start of the current chunk
                    seq_idx_chunk_start = tl.load(seq_idx_ptr +
                                                  min(c * chunk_size, seqlen) *
                                                  stride_seq_idx_seqlen)
                    logical_chunk_idx += (seq_idx_chunk_end -
                                          seq_idx_chunk_start)
                    # - load the chunk offset:
                    c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
                                    mask=logical_chunk_idx < chunk_meta_num,
                                    other=0)
                    # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
                    if c_off > 0:
                        # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
                        dA_cs_boundary = tl.load(
                            dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
                            (c_off - 1) * stride_dA_cs_csize,
                            mask=(c_off - 1) > -1 and c_off < chunk_size,
                            other=0.0)
                        dA_cs -= dA_cs_boundary

                    # - increment logical chunk index for every physical chunk
                    logical_chunk_idx += 1
            else:
                scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
            prev_seq_idx_chunk_end = seq_idx_chunk_end

        scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
        states = scale * states + new_states
        if c < nchunks - 1:
            tl.store(out_ptrs, states, mask=offs_m < dim)
@ -146,28 +178,36 @@ def _state_passing_fwd_kernel(

def _state_passing_fwd(
    states,
    dA_chunk_cumsum,
    dA_cumsum,
    initial_states=None,
    seq_idx=None,
    chunk_size=None,
    out_dtype=None,
    is_cont_batched=False,
    chunk_offsets=None,
):
    batch, nchunks, nheads, dim = states.shape
    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
    if chunk_size is None:
        chunk_size = dA_cumsum.shape[-1]
    else:
        assert chunk_size == dA_cumsum.shape[-1]
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    if initial_states is not None:
        if is_cont_batched:
            # - if cu_seqlens is provided, then the initial states
            #   are used for continuous batching. In which case we
            #   require seq_idx to be provided
            assert seq_idx is not None, ""
            assert seq_idx is not None, "seq_idx must be provided for continuous batching"
            # - we also need chunk_offsets to be provided, to account
            #   for computation of dA_cumsum from the start of the
            #   sequence
            assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching"
        else:
            # - this is the regular batching case, where initial
            #   states are used for each example of the batch.
            assert initial_states.shape == (batch, nheads, dim)

    if seq_idx is not None:
        assert chunk_size is not None
        seqlen = seq_idx.shape[-1]
        assert seq_idx.shape == (batch, seqlen)
    out_dtype = states.dtype if out_dtype is None else out_dtype
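Stripped of Triton, masking, and continuous batching, the recurrence `_state_passing_fwd` computes per (batch, head) is a scan over chunks: the state entering chunk `c` is the previous passed state scaled by `exp(dA_cumsum)` at the end of chunk `c-1`, plus that chunk's local state. A single-sequence PyTorch reference sketch (toy shapes, no `seq_idx` or `chunk_offsets` handling):

```python
# Single-sequence reference for the inter-chunk state-passing recurrence.
import torch

nchunks, d = 4, 8
local_states = torch.randn(nchunks, d)  # per-chunk states from the chunked scan
dA_cs_end = torch.randn(nchunks)        # dA cumsum at the end of each chunk
initial_state = torch.zeros(d)

passed = [initial_state]  # passed[c] is the state entering chunk c
for c in range(nchunks):
    passed.append(torch.exp(dA_cs_end[c]) * passed[-1] + local_states[c])

out = torch.stack(passed[:-1])  # analogous to what the kernel writes to `out`
final_state = passed[-1]        # analogous to `final_states`
print(out.shape, final_state.shape)  # torch.Size([4, 8]) torch.Size([8])
```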
@ -183,13 +223,15 @@ def _state_passing_fwd(
        states,
        out,
        final_states,
        dA_chunk_cumsum,
        dA_cumsum,
        initial_states,
        seq_idx,
        chunk_offsets,
        len(chunk_offsets) if chunk_offsets is not None else 0,
        dim,
        nchunks,
        seqlen if seq_idx is not None else 0,
        chunk_size if seq_idx is not None else 0,
        chunk_size,
        states.stride(0),
        states.stride(1),
        states.stride(2),
@ -201,9 +243,10 @@ def _state_passing_fwd(
        final_states.stride(0),
        final_states.stride(1),
        final_states.stride(2),
        dA_chunk_cumsum.stride(0),
        dA_chunk_cumsum.stride(2),
        dA_chunk_cumsum.stride(1),
        dA_cumsum.stride(0),
        dA_cumsum.stride(2),
        dA_cumsum.stride(1),
        dA_cumsum.stride(3),
        *((
            initial_states.stride(0),
            initial_states.stride(1),

@ -514,7 +514,8 @@ def create_py_executor_instance(
        resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
        model_engine.set_lora_model_config(
            lora_config.lora_target_modules,
            lora_config.trtllm_modules_to_hf_modules)
            lora_config.trtllm_modules_to_hf_modules,
            lora_config.swap_gate_up_proj_lora_b_weight)

    max_num_sequences = executor_config.max_batch_size * mapping.pp_size

@ -16,7 +16,6 @@ from tensorrt_llm.mapping import CpType
from ..distributed import Distributed
from .llm_request import (ExecutorRequest, LlmRequest,
                          executor_request_to_llm_request)
from .sampler import Sampler, TorchSampler

SHUTDOWN_REQUEST_ID = -1

@ -707,21 +706,19 @@ class ExecutorRequestQueue:

    def set_exclude_last_generation_logits(self,
                                           disable_overlap_scheduler: bool,
                                           sampler: Sampler) -> None:
                                           pp_size: int) -> None:
        # When the overlap scheduler is enabled, then when starting to handle a new prompt,
        # sample_async is called twice before the first call to update_requests:
        # - 1st time as a context request that handles the 1st generated token
        # - 2nd time as a generation request that handles the 2nd generated token,
        # and only after these two calls is the sampler's update_requests method called.
        # So in a sampler that works by the expected flow of handling the logits in
        # sample_async (TorchSampler is an anomaly that instead does that on
        # update_requests), every update_request doesn't handle the newest token, but one
        # sample_async, every update_request doesn't handle the newest token, but one
        # before it. Since all these calls work on the same request object, its
        # logits storage contains the logits of both the token update_requests should work
        # on and its next token. Thus, excluding the last generation logits from any
        # getter is required, when not using TorchSampler.
        self.should_exclude_last_generation_logits = not disable_overlap_scheduler and not isinstance(
            sampler, TorchSampler)
        # getter is required.
        self.should_exclude_last_generation_logits = not disable_overlap_scheduler and pp_size == 1

    def _should_exclude_last_generation_logits(self) -> bool:
        return self.should_exclude_last_generation_logits

Some files were not shown because too many files have changed in this diff.