TensorRT-LLMs/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/gemmCommon.h
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Force the order of headers to avoid "definition not found" errors.
// clang-format off
#include <numeric>
#include <set>
#include "trtllmGenSrc/KernelParams.h"
#include "trtllmGenSrc/Enums.h"
#include "trtllmGenSrc/KernelTraits.h"
#include "trtllmGenSrc/SfLayout.h"
#include "gemmList.h"
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/envUtils.h"
// clang-format on
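// Note: TLLM_CHECK_ERROR_FMT deliberately discards its printf-style arguments and reports a
// generic failure message; TLLM_CHECK_WARNING below is aliased to the same hard check.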
#define TLLM_CHECK_ERROR_FMT(cond, ...) \
do \
{ \
if (!(cond)) \
{ \
TLLM_CHECK_WITH_INFO(false, "TRTLLM-GEN kernel launch failed"); \
} \
} while (0)
#define TLLM_CHECK_WARNING TLLM_CHECK_ERROR_FMT
namespace tensorrt_llm
{
namespace kernels
{
namespace trtllmGenFp8BlockScaleMoe
{
namespace gemmCommon
{
struct MyOptions
{
int mTopK;
bool mBatchM;
bool mTransposeMmaOutput;
bool mUseDeepSeekFp8;
bool mUseRouteAct;
bool mRouteAct;
bool mUseShuffledMatrixA;
int mTileM;
int mTileN;
int mTileK;
int mEpilogueTileM;
int mEpilogueTileN;
int mMmaM;
int mMmaN;
int mMmaK;
int mNumExperts;
int mNumTokens;
int mM;
int mN;
int mK;
tg::Dtype mDtypeElt;
tg::Dtype mDtypeC;
tg::Dtype mDtypeAcc;
bool mUseTmaStore;
bool mUseTwoTmaLoadWarps;
float mAtol;
float mRtol;
int mNumSlicesForSplitK;
int mClusterDimX;
int mClusterDimY;
int mClusterDimZ;
::gemm::AllReduceAlgo mAllReduceAlgo;
::gemm::SplitK mSplitK;
::gemm::KernelTraits mKernelTraits;
int mNumStages;
int mNumStagesMma;
bool mUseFusedAct;
bool mGateUseClusterSplitK;
int mProjGateNumSplitKSlices;
int32_t* mPtrNumNonExitingCtas{nullptr};
int32_t* mPtrTotalNumPaddedTokens{nullptr};
int32_t* mPtrCtaIdxXyToBatchIdx{nullptr};
int32_t* mPtrCtaIdxXyToMnLimit{nullptr};
tg::SfLayout mSfLayoutB;
tg::SfLayout mSfLayoutC;
bool mUseCustomLowLatencyImpl;
bool mSliceK;
int mNumSlicesForSliceK;
std::vector<int> mBatchedM;
std::vector<int> mBatchedN;
int mNumBatches;
bool mAllToAllRouteAct;
bool mIsStaticBatch;
bool mUsePerTokenSfA;
bool mUsePerTokenSfB;
};
inline void copyKernelInfoToOptions(GemmInfo const& kernelInfo, MyOptions& options)
{
options.mUseDeepSeekFp8 = kernelInfo.blockScale && kernelInfo.dtypeElt == tg::Dtype::E4m3;
options.mUseRouteAct = kernelInfo.permuteFusion;
options.mRouteAct = kernelInfo.permuteFusion;
options.mUseShuffledMatrixA = kernelInfo.shuffledMatrixA;
options.mTileM = kernelInfo.tileM;
options.mTileN = kernelInfo.tileN;
options.mTileK = kernelInfo.tileK;
options.mEpilogueTileM = kernelInfo.epilogueTileM;
options.mEpilogueTileN = kernelInfo.epilogueTileN;
options.mMmaM = kernelInfo.mmaM;
options.mMmaN = kernelInfo.mmaN;
options.mMmaK = kernelInfo.mmaK;
options.mNumSlicesForSplitK = kernelInfo.numSlicesForSplitK;
options.mDtypeElt = kernelInfo.dtypeElt;
options.mDtypeC = kernelInfo.dtypeC;
options.mDtypeAcc = kernelInfo.dtypeAcc;
options.mUseTmaStore = kernelInfo.useTmaStore;
options.mNumStagesMma = kernelInfo.numStagesMma;
options.mNumStages = kernelInfo.numStages;
options.mUseFusedAct = kernelInfo.useFusedAct;
options.mGateUseClusterSplitK = kernelInfo.gateUseClusterSplitK;
options.mProjGateNumSplitKSlices = kernelInfo.projGateNumSplitKSlices;
options.mSliceK = kernelInfo.sliceK;
options.mNumSlicesForSliceK = kernelInfo.mNumSlicesForSliceK;
options.mSfLayoutB = kernelInfo.mSfLayoutB;
options.mSfLayoutC = kernelInfo.mSfLayoutC;
options.mUsePerTokenSfA = kernelInfo.usePerTokenSfA;
options.mUsePerTokenSfB = kernelInfo.usePerTokenSfB;
}
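// Note: copyKernelInfoToOptions() only fills the kernel-derived fields of MyOptions; the
// problem-size fields (mM, mN, mK, mNumTokens, mNumExperts, mTopK) and the runtime pointers
// are expected to be set by the caller before the check-and-update passes below run.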
namespace gemm
{
template <typename T>
inline std::string toString(T e)
{
return std::to_string(e);
}
template <>
inline std::string toString(trtllm::gen::Dtype e)
{
return trtllm::gen::dtypeToString(e);
}
inline int32_t divUp(int32_t a, int32_t b)
{
return (a + b - 1) / b;
}
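// For example, divUp(300, 128) == 3: three tiles of 128 are needed to cover 300 rows.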
inline int32_t getShuffleBlockSize(int epilogueTileM)
{
int shuffleBlockSize = 16;
if (epilogueTileM % 128 == 0)
{
shuffleBlockSize = 32;
}
return shuffleBlockSize;
}
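// For example, an epilogue tile M of 128 (or 256) shuffles in blocks of 32 rows, while
// an epilogue tile M of 64 falls back to blocks of 16 rows.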
using namespace ::gemm;
inline void checkAndUpdateGemmOptions(
MyOptions& options, bool isBlackwell, bool usesAtolFromArg, bool usesRtolFromArg, int /* tpGrpSize */)
{
if (options.mUseCustomLowLatencyImpl)
{
TLLM_CHECK_ERROR(isBlackwell, "LowLatencyGemm is only supported on Blackwell");
TLLM_CHECK_ERROR(options.mDtypeElt == tg::Dtype::E2m1, "LowLatencyGemm only supports nvfp4 (for now).");
TLLM_CHECK_ERROR(options.mDtypeAcc == tg::Dtype::Fp32, "LowLatencyGemm accumulates in fp32.");
TLLM_CHECK_ERROR(options.mTileN == 32 || options.mTileN == 64 || options.mTileN == 128,
"LowLatencyGemm supports tileN of 32, 64, or 128.");
if (options.mTileN == 32)
{
TLLM_CHECK_ERROR(options.mNumSlicesForSplitK == 1 || options.mNumSlicesForSplitK == 2
|| options.mNumSlicesForSplitK == 4,
"LowLatencyGemm with tileN == 32 supports NumSlicesForSplitK of 1, 2 or 4.");
}
else
{
TLLM_CHECK_ERROR(
options.mNumSlicesForSplitK == 1, "LowLatencyGemm does not use SplitK > 1 with tileN > 32.");
}
return;
}
if (options.mDtypeElt == tg::Dtype::E4m3 && options.mMmaK != 32)
{
TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, ") for ", gemm::toString(options.mDtypeElt).c_str(),
". Setting MmaK to 32");
options.mMmaK = 32;
options.mTileK = std::max(options.mMmaK, options.mTileK);
}
// NvFp4 constraints
if (options.mDtypeElt == tg::Dtype::E2m1)
{
TLLM_CHECK_ERROR(isBlackwell, "FP4 is only supported on Blackwell");
TLLM_CHECK_ERROR(options.mSfLayoutB == tg::SfLayout::R128c4 || options.mSfLayoutB == tg::SfLayout::R8c4,
"Only the 128x4 and 8x4 SF layouts are supported for B, got ", tg::sfLayoutToString(options.mSfLayoutB));
if (options.mMmaK != 64)
{
int newTileK = 64 * divUp(options.mTileK, 64);
TLLM_LOG_WARNING("Unsupported MmaK (", options.mMmaK, ") for ", gemm::toString(options.mDtypeElt).c_str(),
". Setting MmaK to 64 and TileK to ", newTileK);
options.mMmaK = 64;
options.mTileK = newTileK;
}
if (options.mMmaM != 128)
{
int newTileM = 128 * divUp(options.mTileM, 128);
TLLM_LOG_WARNING("Unsupported MmaM (", options.mMmaM, ") for ", gemm::toString(options.mDtypeElt).c_str(),
". Setting MmaM to 128 and TileM to ", newTileM);
options.mMmaM = 128;
options.mTileM = newTileM;
}
// TileN must be a multiple of the number of rows per SF tile.
// TODO: relax this restriction.
int const numSfTileRowsB = options.mSfLayoutB == tg::SfLayout::R128c4 ? 128 : 8;
TLLM_CHECK_ERROR(options.mTileN % numSfTileRowsB == 0, "TileN (", options.mTileN, ") must be a multiple of ",
numSfTileRowsB, " for B SF layout ", tg::sfLayoutToString(options.mSfLayoutB));
// The MMA N may only be smaller than 64 if it is equal to the tile N.
TLLM_CHECK_ERROR(options.mMmaN >= 64 || options.mMmaN == options.mTileN, "MmaN (", options.mMmaN,
") must be >= 64 or equal to TileN (", options.mTileN, ") for ", gemm::toString(options.mDtypeElt));
}
if (options.mDtypeC == tg::Dtype::E2m1)
{
TLLM_CHECK_ERROR(isBlackwell, "FP4 is only supported on Blackwell");
TLLM_CHECK_ERROR(options.mSfLayoutC == tg::SfLayout::R128c4 || options.mSfLayoutC == tg::SfLayout::R8c4,
"Only the 128x4 and 8x4 SF layouts are supported for C.");
int const numSfTileRowsC = options.mSfLayoutC == tg::SfLayout::R128c4 ? 128 : 8;
TLLM_CHECK_ERROR(options.mTileN % numSfTileRowsC == 0, "TileN (", options.mTileN, ") must be a multiple of ",
numSfTileRowsC, " for C SF layout ", tg::sfLayoutToString(options.mSfLayoutC));
int const hiddenDim = options.mTransposeMmaOutput ? options.mM : options.mN;
TLLM_CHECK_ERROR(hiddenDim % 64 == 0, "Hidden dim (", hiddenDim, ") must be a multiple of 64 for FP4 outputs.");
TLLM_CHECK_ERROR(!options.mTransposeMmaOutput || options.mUseShuffledMatrixA,
"Transposing FP4 outputs requires shuffled A.");
}
// If dtypeC is unspecified (Dtype::Void), assign to the input dtype.
if (options.mDtypeC == tg::Dtype::Void)
{
options.mDtypeC = options.mDtypeElt;
}
// Fall back to the output tile sizes when the epilogue tile sizes do not divide them evenly.
if (options.mTileM % options.mEpilogueTileM != 0)
{
TLLM_LOG_WARNING("TileM (", options.mTileM, ") must be divisible by EpilogueTileM (", options.mEpilogueTileM,
"). Setting EpilogueTileM to TileM");
options.mEpilogueTileM = options.mTileM;
}
if (options.mTileN % options.mEpilogueTileN != 0)
{
TLLM_LOG_WARNING("TileN (", options.mTileN, ") must be divisible by EpilogueTileN (", options.mEpilogueTileN,
"). Setting EpilogueTileN to TileN");
options.mEpilogueTileN = options.mTileN;
}
// On Hopper, epilogue tile sizes are the same as output tiles.
if (!isBlackwell)
{
options.mEpilogueTileM = options.mTileM;
options.mEpilogueTileN = options.mTileN;
TLLM_LOG_WARNING("Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively");
}
// Unsupported epilogue tile size.
if (options.mMmaM == 128 && options.mEpilogueTileM != options.mTileM)
{
TLLM_LOG_WARNING("When MmaM = 128, EpilogueTileM must be equal to TileM. Setting EpilogueTileM to TileM");
options.mEpilogueTileM = options.mTileM;
}
TLLM_CHECK_ERROR(options.mM > 0 && options.mN > 0 && options.mK > 0, "M, N and K must be larger than 0");
TLLM_CHECK_ERROR(options.mNumSlicesForSplitK > 0, "Split K must be larger than 0.");
TLLM_CHECK_ERROR(options.mK % options.mNumSlicesForSplitK == 0, "K must be divisible by NumSlicesForSplitK.");
TLLM_CHECK_ERROR((options.mK / options.mNumSlicesForSplitK) % options.mTileK == 0,
"K / NumSlicesForSplitK must be divisible by TileK. Found TileK=", options.mTileK, " and K=", options.mK,
" and NumSlicesForSplitK=", options.mNumSlicesForSplitK);
if (options.mUseShuffledMatrixA)
{
auto const shuffleBlockSize = getShuffleBlockSize(options.mEpilogueTileM);
TLLM_CHECK_ERROR(options.mM % shuffleBlockSize == 0, "M must be a multiple of shuffle block size (",
shuffleBlockSize, ") when useShuffledMatrixA");
}
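// For example, with EpilogueTileM = 64 the shuffle block size is 16, so M must be a
// multiple of 16; for EpilogueTileM = 128 it is 32.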
TLLM_CHECK_ERROR(options.mMmaM <= options.mEpilogueTileM && options.mMmaN <= options.mEpilogueTileN,
"EpilogueTileM and EpilogueTileN must be greater than or equal to the respective MMA atom sizes.");
TLLM_CHECK_ERROR(options.mTileM % options.mEpilogueTileM == 0 && options.mTileN % options.mEpilogueTileN == 0,
"TileM and TileN must be divisible by EpilogueTileM and EpilogueTileN respectively.");
TLLM_CHECK_ERROR(
options.mClusterDimX == 1 && options.mClusterDimY == 1, "GEMM does not support clusters in the X and Y dimensions.");
TLLM_CHECK_ERROR(
options.mClusterDimZ == 1 || options.mNumSlicesForSplitK > 1, "Cluster DimZ is only allowed for split-k.");
TLLM_CHECK_ERROR(options.mTileM <= 128, "GEMM does not support TileM > 128.");
// Force the tolerance to a slightly higher value for FP16/BF16.
if (options.mDtypeElt == tg::Dtype::Fp16 || options.mDtypeElt == tg::Dtype::Bfloat16)
{
if (!usesAtolFromArg)
{
options.mAtol = (options.mAllReduceAlgo != gemm::AllReduceAlgo::None) ? 4e-3f : 2e-4f;
}
}
// Force the tolerance to a slightly higher value for DeepSeek FP8 with FP16/BF16 output.
if (options.mUseDeepSeekFp8 && (options.mDtypeC == tg::Dtype::Fp16 || options.mDtypeC == tg::Dtype::Bfloat16))
{
if (!usesAtolFromArg)
{
options.mAtol = 1e-2f;
}
if (!usesRtolFromArg)
{
options.mRtol = 1e-2f;
}
}
// When the A-matrix is shuffled, the output must be transposed.
if (options.mUseShuffledMatrixA)
{
// TODO add matrix shuffle for N-major epilogue.
TLLM_CHECK_ERROR(options.mTransposeMmaOutput,
"Shuffled matrix A is only supported with M-major epilogue. Set -transposeMmaOutput");
}
// Check all-reduce options.
if (options.mAllReduceAlgo == AllReduceAlgo::OneShot)
{
// One shot is implemented with PTX cp.reduce.async.bulk.tensor which supports only the
// following types for reduce add: u32, s32, u64, f32, f16, bf16.
//
// See: https://docs.nvidia.com/cuda/parallel-thread-execution/
// #data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor
std::set<tg::Dtype> dtypeSupported{tg::Dtype::UInt32, tg::Dtype::Int32, tg::Dtype::UInt64, tg::Dtype::Fp32,
tg::Dtype::Fp16, tg::Dtype::Bfloat16};
TLLM_CHECK_ERROR(dtypeSupported.find(options.mDtypeC) != dtypeSupported.end(), "Unsupported output dtype ",
tg::dtypeToString(options.mDtypeC));
}
else if (options.mAllReduceAlgo == AllReduceAlgo::TwoShot)
{
// TODO(anchengc):
// Input dtype == output dtype -> can perform all-reduce in-place.
// Input dtype != output dtype -> must perform all-reduce out of place.
TLLM_CHECK_ERROR_FMT(options.mDtypeC == options.mDtypeAcc,
"Not implemented - mixed dtype (dtypeC (%s) != dtypeAcc (%s)) requires out of place update",
tg::dtypeToString(options.mDtypeC).c_str(), tg::dtypeToString(options.mDtypeAcc).c_str());
}
if (options.mAllReduceAlgo != AllReduceAlgo::None)
{
TLLM_CHECK_ERROR(options.mUseTmaStore, "Non-TMA store with all-reduce is not implemented");
}
if (options.mNumSlicesForSplitK == 1)
{
// No split-k.
options.mSplitK = SplitK::None;
}
else if (options.mNumSlicesForSplitK > 1 && options.mClusterDimZ == 1)
{
// Split-k with exchange through gmem.
options.mSplitK = SplitK::Gmem;
}
else
{
// Split-k with exchange through Dsmem.
options.mSplitK = SplitK::Dsmem;
}
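// Summary: with more than one split-K slice, partial sums are exchanged either through
// global memory (clusterDimZ == 1) or through DSMEM within a CGA (clusterDimZ > 1).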
// For GMEM-based split-K, we write 4 elements at once.
if (options.mSplitK == SplitK::Gmem)
{
TLLM_CHECK_ERROR((options.mM * options.mN) % 4 == 0, "M * N must be a multiple of 4 for Split-K");
}
if (options.mNumSlicesForSplitK > 1)
{
if ((options.mEpilogueTileM != options.mTileM || options.mEpilogueTileN != options.mTileN)
&& !options.mUseDeepSeekFp8)
{
options.mEpilogueTileM = options.mTileM;
options.mEpilogueTileN = options.mTileN;
TLLM_LOG_WARNING("Overwriting epilogueTileM and epilogueTileN to match tileM and tileN respectively");
}
}
if (options.mSplitK == SplitK::Dsmem)
{
TLLM_CHECK_ERROR(options.mClusterDimZ == options.mNumSlicesForSplitK,
"CGA size must be equal to the number of slices in split-k");
}
// DeepSeek Fp8
if (!options.mUseDeepSeekFp8)
{
TLLM_CHECK_ERROR(options.mNumStagesMma == 1, "Non-DeepSeekFp8 requires numStagesMma == 1");
}
if (options.mUseDeepSeekFp8)
{
TLLM_CHECK_ERROR(options.mDtypeElt == tg::Dtype::E4m3, "A and B dtype must be E4m3 for DeepSeek Fp8. Found ",
tg::dtypeToString(options.mDtypeElt));
TLLM_CHECK_ERROR(isBlackwell, "DeepSeek Fp8 is not supported for Hopper");
TLLM_CHECK_ERROR(options.mAllReduceAlgo == AllReduceAlgo::None, "DeepSeek Fp8 does not support AllReduce");
// Check that TileK = 128 for correct scaling of every 128 channels.
TLLM_CHECK_ERROR(options.mTileK == 128, "Tile-K must be equal to 128 for DeepSeek Fp8");
// Tile sizes of the output hidden dimension.
auto hiddenDim = options.mTransposeMmaOutput ? options.mM : options.mN;
auto hiddenDimPerOutputTile = options.mTransposeMmaOutput ? options.mTileM : options.mTileN;
auto hiddenDimPerEpilogueTile = options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN;
auto hiddenDimPerMma = options.mTransposeMmaOutput ? options.mMmaM : options.mMmaN;
auto hiddenDimName = options.mTransposeMmaOutput ? "M" : "N";
// TLLM_CHECK_WARNING(options.mNumStagesMma > 1, "DeepSeekFp8 recommends >1 MMA accumulator stages.");
// Update the number of stages of the MMA accumulator pipeline. TODO: enable by default for
// deepseek.
// options.mNumStagesMma = 2;
// Use two MMA warps to reduce mbar trywait latency. TODO: enable by default for deepseek.
// options.mUseTwoMmaWarps = true;
// Make sure the GEMM-M/N dimension is a multiple of 128 when using DeepSeek FP8.
TLLM_CHECK_ERROR(hiddenDim % 128 == 0, "GEMM-", hiddenDimName,
" must be a multiple of 128 when using DeepSeek Fp8. Found ", hiddenDim);
// Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8.
TLLM_CHECK_ERROR(
options.mK % 128 == 0, "GEMM-K must be a multiple of 128 when using DeepSeek Fp8. Found ", options.mK);
// Check that half of the output tile can be processed with the epilogue tile granularity.
TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerEpilogueTile == 0, "DeepSeek Fp8 requires Tile",
hiddenDimName, " / 2 (", hiddenDimPerOutputTile / 2, ") to be a multiple of EpilogueTile", hiddenDimName,
" (", hiddenDimPerEpilogueTile, ")");
// Check that half of the output tile can be processed with the MMA granularity.
TLLM_CHECK_ERROR((hiddenDimPerOutputTile / 2) % hiddenDimPerMma == 0, "DeepSeek Fp8 requires Tile",
hiddenDimName, " / 2 (", hiddenDimPerOutputTile / 2, ") to be a multiple of Mma", hiddenDimName, " (",
hiddenDimPerMma, ")");
}
TLLM_CHECK_ERROR((options.mK / options.mNumSlicesForSplitK) % (options.mTileK * 2) == 0,
"Size K / splitK must be a multiple of TileK * 2. Found TileK=", options.mTileK, " and K=", options.mK,
" and numSlicesForSplitK=", options.mNumSlicesForSplitK);
if (options.mSliceK)
{
TLLM_CHECK_ERROR(isBlackwell, "Slice-K is not supported on Hopper");
TLLM_CHECK_ERROR(!options.mUseDeepSeekFp8, "DeepSeek Fp8 GEMM is not supported for slice-K");
TLLM_CHECK_ERROR(options.mUseTwoTmaLoadWarps, "Slice-K requires two warp load for A and B");
TLLM_CHECK_ERROR(options.mTransposeMmaOutput, "Slice-K requires transpose mma output");
TLLM_CHECK_ERROR(options.mUseShuffledMatrixA, "Slice-K requires shuffled matrix A");
TLLM_CHECK_ERROR(options.mTileK % 128 == 0, "Slice-K requires TileK to be a multiple of 128");
TLLM_CHECK_ERROR(options.mMmaM == 128, "Slice-K requires MmaM == 128");
TLLM_CHECK_ERROR(options.mTileN == options.mEpilogueTileN, "TileN must be equal to EpilogueTileN for slice-K");
TLLM_LOG_WARNING("Overwriting TileM and EpilogueTileM to 32 for slice-K");
// FIXME: it is possible to remove this restriction.
options.mTileM = 32;
options.mEpilogueTileM = 32;
TLLM_CHECK_ERROR(options.mDtypeElt == tg::Dtype::E4m3, "Slice-K requires e4m3 input dtype");
options.mNumSlicesForSliceK = 4;
TLLM_CHECK_ERROR((options.mTileK / options.mMmaK) % options.mNumSlicesForSliceK == 0, "TileK (", options.mTileK,
") / MmaK (", options.mMmaK, ") must be a multiple of mNumSlicesForSliceK (", options.mNumSlicesForSliceK,
")");
}
// Init kernel traits.
options.mKernelTraits = KernelTraits(options.mDtypeElt, options.mDtypeC, options.mDtypeAcc, options.mTileM,
options.mTileN, options.mTileK, options.mEpilogueTileM, options.mEpilogueTileN, options.mNumStages,
options.mNumStagesMma, options.mNumSlicesForSplitK, options.mNumSlicesForSliceK, options.mSplitK,
options.mUseTmaStore, options.mTransposeMmaOutput, options.mAllReduceAlgo, options.mUseDeepSeekFp8);
}
} // namespace gemm
namespace gemmGatedAct
{
// Check if the options are valid or not.
inline void checkAndUpdateGemmGatedActOptions(
MyOptions& options, bool isBlackwell, bool usesAtolFromArg, bool usesRtolFromArg)
{
// tmpOut is already transposed at this stage
auto const hiddenSizeStr = options.mTransposeMmaOutput ? "M" : "N";
auto const hiddenSize = options.mTransposeMmaOutput ? options.mM : options.mN;
auto const hiddenEpilogueTileSize = options.mTransposeMmaOutput ? options.mEpilogueTileM : options.mEpilogueTileN;
TLLM_CHECK_WITH_INFO(hiddenSize % 2 == 0, "%s must be a multiple of 2.", hiddenSizeStr);
TLLM_CHECK_WITH_INFO((options.mTransposeMmaOutput ^ options.mUseShuffledMatrixA) == 0,
"Transposed MMA output and shuffled matrix A must be used together.");
if (options.mUseTmaStore)
{
TLLM_CHECK_WITH_INFO(
hiddenEpilogueTileSize * trtllm::gen::dtypeGetNumBits(options.mDtypeElt) / /* bits */ 8 % 32 == 0,
"Unsupported output hidden tile size");
}
if (options.mUseDeepSeekFp8)
{
TLLM_CHECK_WITH_INFO(hiddenSize % 256 == 0, "Output hidden size must be a multiple of 256");
}
gemm::checkAndUpdateGemmOptions(options, isBlackwell, usesAtolFromArg, usesRtolFromArg,
/* tpGrpSize */ 1);
}
} // namespace gemmGatedAct
namespace batchedGemm
{
inline void checkAndUpdateGemmOptions(MyOptions& options, bool isBlackwell, bool usesAtolFromArg, bool usesRtolFromArg)
{
if (options.mUseFusedAct)
{
gemmGatedAct::checkAndUpdateGemmGatedActOptions(options, isBlackwell, usesAtolFromArg, usesRtolFromArg);
}
else
{
gemm::checkAndUpdateGemmOptions(options, isBlackwell, usesAtolFromArg, usesRtolFromArg, 1 /* tpGrpSize */);
}
bool batchM = options.mBatchM;
if (batchM)
{
if (options.mBatchedM.empty())
{
options.mBatchedM.push_back(128);
options.mBatchedM.push_back(256);
}
options.mNumBatches = options.mBatchedM.size();
}
else
{
if (options.mBatchedN.empty())
{
options.mBatchedN.push_back(128);
options.mBatchedN.push_back(256);
}
options.mNumBatches = options.mBatchedN.size();
}
for (int b = 0; b < options.mNumBatches; b++)
{
if (batchM)
{
TLLM_CHECK_ERROR(options.mN > 0 && options.mK > 0, "N and K must be larger than 0");
TLLM_CHECK_ERROR(options.mN >= options.mTileN && options.mK >= options.mTileK,
"N and K must be equal to or larger than TileN and TileK respectively.");
TLLM_CHECK_ERROR(options.mN % options.mTileN == 0 && options.mK % options.mTileK == 0,
"N and K must be divisible by TileN and TileK respectively.");
TLLM_CHECK_ERROR(!options.mTransposeMmaOutput, "When batchM, the MMA output has to be row-major.");
}
else
{
TLLM_CHECK_ERROR(options.mM > 0 && options.mK > 0, "M and K must be larger than 0");
TLLM_CHECK_ERROR(options.mM >= options.mTileM && options.mK >= options.mTileK,
"M and K must be equal to or larger than TileM and TileK respectively.");
TLLM_CHECK_ERROR(options.mM % options.mTileM == 0 && options.mK % options.mTileK == 0,
"M and K must be divisible by TileM and TileK respectively.");
TLLM_CHECK_ERROR(options.mTransposeMmaOutput, "When batchN, the MMA output has to be column-major.");
}
}
if (options.mUseDeepSeekFp8)
{
if (batchM)
{
// Make sure the GEMM-N dimension is a multiple of 128 when using DeepSeek FP8.
TLLM_CHECK_ERROR(
options.mN % 128 == 0, "GEMM-N must be a multiple of 128 when using DeepSeek Fp8. Found ", options.mN);
}
else
{
// Make sure the GEMM-M dimension is a multiple of 128 when using DeepSeek FP8.
TLLM_CHECK_ERROR(
options.mM % 128 == 0, "GEMM-M must be a multiple of 128 when using DeepSeek Fp8. Found ", options.mM);
}
// Make sure the GEMM-K dimension is a multiple of 128 when using DeepSeek FP8.
TLLM_CHECK_ERROR(
options.mK % 128 == 0, "GEMM-K must be a multiple of 128 when using DeepSeek Fp8. Found ", options.mK);
TLLM_CHECK_ERROR(options.mDtypeC != tg::Dtype::E2m1 && options.mDtypeElt != tg::Dtype::E2m1,
"E2m1 is not supported with DeepSeek FP8");
}
if (options.mAllToAllRouteAct)
{
if (batchM)
{
TLLM_CHECK_ERROR(options.mNumTokens <= options.mTileM, "Max number of tokens per expert (",
options.mNumTokens, ") must not exceed TileM (", options.mTileM, ")");
}
else
{
TLLM_CHECK_ERROR(options.mNumTokens <= options.mTileN, "Max number of tokens per expert (",
options.mNumTokens, ") must not exceed TileN (", options.mTileN, ")");
}
}
TLLM_CHECK_ERROR(options.mUseTmaStore, "Only TMA store is supported.");
if (options.mAllToAllRouteAct && options.mRouteAct)
{
// Turning off mRouteAct when we do mAllToAllRouteAct
TLLM_LOG_INFO("Turning off mRouteAct when we do mAllToAllRouteAct");
options.mRouteAct = false;
}
}
} // namespace batchedGemm
class BatchedGemmData
{
public:
int mNumCtaX;
int mNumCtaY;
int mNumCtaZ;
int mClusterDimX;
int mClusterDimY;
int mClusterDimZ;
void* mA;
void* mB;
void* mC;
float* mScaleC;
float* mScaleGate;
void* mSfA;
void* mSfB;
void* mSfC;
int32_t const* mRouteMap;
float* mDqSfsTokens;
void* mPerTokenSfB;
// Pointer for partial row max for DeepSeek computation.
float* mPtrPartialRowMax;
// Flags in global memory that sync on "exit" for row max computation.
uint32_t* mPtrRowMaxCompletionBars;
// Number of CTAs that do not exit early. For when the workload is smaller than the CTA grid.
int32_t* mPtrNumNonExitingCtas;
// Pointer to total number of padded tokens
int32_t* mPtrTotalNumPaddedTokens;
// Pointer to CTA index X/Y to batch index
int32_t* mPtrCtaIdxXyToBatchIdx;
// Pointer to CTA index X/Y to tile index **expanded** M/N for batched dimension
int32_t* mPtrCtaIdxXyToMnLimit;
};
inline void setSingleBatchedGemmData(void* A, void* B, void* C, float* scaleC, float* scaleGate, float* dqSfsA, float* dqSfsB,
float* dqSfsC, void* sfA, void* sfB, void* sfC, int32_t* permutedIdxToTokenIdx, float* ptrPartialRowMax,
uint32_t* ptrRowMaxCompletionBars, void* expertWeights,
// const bool projUp,
int const numSlicesForSplitK, MyOptions& args, BatchedGemmData& data, int32_t maxNumPaddedTokens)
{
data.mA = A;
data.mB = B;
data.mC = C;
if (args.mUseDeepSeekFp8)
{
data.mSfA = dqSfsA;
data.mSfB = dqSfsB;
data.mSfC = dqSfsC;
// Avoid illegal read when compiling with debug info
data.mScaleC = dqSfsC;
data.mScaleGate = dqSfsC;
}
else if (args.mDtypeElt == tg::Dtype::E4m3)
{
data.mScaleC = scaleC;
data.mScaleGate = scaleGate;
}
else // dtypeElt == e2m1
{
data.mSfA = sfA;
data.mSfB = sfB;
data.mSfC = sfC;
data.mScaleC = scaleC;
data.mScaleGate = scaleGate;
}
if (args.mUsePerTokenSfA || args.mUsePerTokenSfB)
{
data.mPerTokenSfB = expertWeights;
}
if (args.mUseRouteAct)
{
data.mRouteMap = permutedIdxToTokenIdx;
}
// Used by the fused activation; these remain nullptr when unused.
data.mPtrPartialRowMax = ptrPartialRowMax;
data.mPtrRowMaxCompletionBars = ptrRowMaxCompletionBars;
data.mPtrNumNonExitingCtas = args.mPtrNumNonExitingCtas;
/*
const int tileN = (projUp) ? args.mProjUpTileN : args.mProjDownTileN;
const int tileM = (projUp) ? args.mProjUpTileM : args.mProjDownTileM;
const int m = (projUp) ? 2 * args.mIntermediateDim : args.mHiddenDim;
*/
int const tileN = args.mTileN;
int const tileM = args.mTileM;
int const m = args.mM;
// The number of tokens per expert and number of active experts are known only at runtime. We
// launch the max possible number of CTAs and use ptrNumNonExitingCtas to decide if a CTA must
// run (if its ID < this number) or early-exit.
// Get maximum number of CTAs in batch dim per expert.
auto maxCtasInBatchDimPerExpert = gemm::divUp(args.mNumTokens, tileN);
// Get maximum enabled experts.
auto maxEnabledExperts = std::min(args.mNumTokens * args.mTopK, args.mNumExperts);
// Get maximum number of CTAs in batch dim.
auto maxNumCtasInBatchDim = maxEnabledExperts * maxCtasInBatchDimPerExpert;
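// Worked example: with mNumTokens = 300, mTopK = 8, mNumExperts = 256 and tileN = 128,
// maxCtasInBatchDimPerExpert = divUp(300, 128) = 3 and maxEnabledExperts =
// min(300 * 8, 256) = 256, giving maxNumCtasInBatchDim = 768.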
data.mNumCtaY = std::min(maxNumCtasInBatchDim, gemm::divUp(maxNumPaddedTokens, tileN));
data.mNumCtaX = gemm::divUp(m, tileM);
data.mNumCtaZ = numSlicesForSplitK;
data.mClusterDimX = 1;
data.mClusterDimY = 1;
data.mClusterDimZ = numSlicesForSplitK;
// Pointer to total number of padded tokens
data.mPtrTotalNumPaddedTokens = args.mPtrTotalNumPaddedTokens;
data.mPtrCtaIdxXyToBatchIdx = args.mPtrCtaIdxXyToBatchIdx;
data.mPtrCtaIdxXyToMnLimit = args.mPtrCtaIdxXyToMnLimit;
}
inline void launchGemmFromData(GemmInfo const& kernelInfo, MyOptions const& options, BatchedGemmData const& batchedGemmData,
CUstream stream, bool usePDL = true)
{
std::shared_ptr<tensorrt_llm::common::CUDADriverWrapper> cuDriver(
tensorrt_llm::common::CUDADriverWrapper::getInstance());
CUmodule cuModule;
CUfunction cuFunction;
TLLM_CU_CHECK(cuDriver->cuModuleLoadData(&cuModule, kernelInfo.data));
TLLM_CU_CHECK(cuDriver->cuModuleGetFunction(&cuFunction, cuModule, kernelInfo.functionName));
if (kernelInfo.sharedMemSize >= 48 * 1024)
{
TLLM_CU_CHECK(cuDriver->cuFuncSetAttribute(
cuFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, kernelInfo.sharedMemSize));
}
auto params = KernelParams::setKernelParams(options, batchedGemmData.mA, batchedGemmData.mB, batchedGemmData.mC,
batchedGemmData.mSfA, batchedGemmData.mSfB,
/* mPerTokenSfA */ nullptr, batchedGemmData.mPerTokenSfB, batchedGemmData.mSfC, batchedGemmData.mScaleC,
batchedGemmData.mScaleGate, batchedGemmData.mRouteMap, batchedGemmData.mPtrPartialRowMax,
batchedGemmData.mPtrRowMaxCompletionBars, batchedGemmData.mPtrNumNonExitingCtas,
batchedGemmData.mPtrTotalNumPaddedTokens, batchedGemmData.mPtrCtaIdxXyToBatchIdx,
batchedGemmData.mPtrCtaIdxXyToMnLimit);
CUlaunchConfig launch_config;
launch_config.blockDimX = kernelInfo.threadsPerCTA;
launch_config.blockDimY = 1;
launch_config.blockDimZ = 1;
launch_config.gridDimX = batchedGemmData.mNumCtaX;
launch_config.gridDimY = batchedGemmData.mNumCtaY;
launch_config.gridDimZ = batchedGemmData.mNumCtaZ;
launch_config.hStream = stream;
launch_config.sharedMemBytes = kernelInfo.sharedMemSize;
CUlaunchAttribute launch_attribute[3];
launch_attribute[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
launch_attribute[0].value.clusterDim.x = batchedGemmData.mClusterDimX;
launch_attribute[0].value.clusterDim.y = batchedGemmData.mClusterDimY;
launch_attribute[0].value.clusterDim.z = batchedGemmData.mClusterDimZ;
launch_attribute[1].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE;
launch_attribute[1].value.clusterSchedulingPolicyPreference = CU_CLUSTER_SCHEDULING_POLICY_DEFAULT;
launch_attribute[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
// Even when usePDL is requested, respect the PDL environment variable.
launch_attribute[2].value.programmaticStreamSerializationAllowed
= usePDL ? tensorrt_llm::common::getEnvEnablePDL() : 0;
launch_config.attrs = launch_attribute;
launch_config.numAttrs = 3;
// std::cout << "kernelInfo.data: " << kernelInfo.data << std::endl;
// std::cout << "kernelInfo.functionName: " << kernelInfo.functionName << std::endl;
TLLM_CHECK_WITH_INFO(kernelInfo.paramsStructSize == sizeof(params),
"Kernel params size mismatch: host-side KernelParams does not match the cubin's expected layout");
void* kernelParamsList[] = {&params};
TLLM_CU_CHECK(cuDriver->cuLaunchKernelEx(&launch_config, cuFunction, kernelParamsList, nullptr));
}
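// Illustrative call sequence (a sketch only; buffer allocation, routing setup and the
// kernel selection from gemmList happen in the surrounding MoE runner):
//
//   MyOptions options{};
//   copyKernelInfoToOptions(kernelInfo, options);
//   // ... fill problem sizes (mM, mN, mK, mNumTokens, ...) and routing pointers ...
//   batchedGemm::checkAndUpdateGemmOptions(
//       options, /* isBlackwell */ true, /* usesAtolFromArg */ false, /* usesRtolFromArg */ false);
//   BatchedGemmData data{};
//   setSingleBatchedGemmData(A, B, C, scaleC, scaleGate, dqSfsA, dqSfsB, dqSfsC, sfA, sfB, sfC,
//       permutedIdxToTokenIdx, ptrPartialRowMax, ptrRowMaxCompletionBars, expertWeights,
//       options.mNumSlicesForSplitK, options, data, maxNumPaddedTokens);
//   launchGemmFromData(kernelInfo, options, data, stream);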
} // namespace gemmCommon
} // namespace trtllmGenFp8BlockScaleMoe
} // namespace kernels
} // namespace tensorrt_llm