revert tlg kernels for ease of merge

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Xiwen Yu 2025-08-19 11:44:36 +08:00
parent c1014e85cc
commit 4a95d88ce2
1948 changed files with 23857 additions and 8585 deletions

View File

@ -65,6 +65,8 @@ std::vector<int64_t> prioritizePredefinedConfigs(int m, int n, int k, std::vecto
//
// Dummy
//
// Qwen3_235B_TP8_EP1_MoE_FC2 m=4096 k=192
if (n /* out_dim */ == 0 && k /* in_dim */ == 0)
{
auto pred = [](BatchedGemmConfig const& config)
@ -100,13 +102,27 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
auto const options = configs[i].mOptions;
auto const tileSize = mOptions.transposeMmaOutput ? options.mTileN : options.mTileM;
// When we include low-latency kernels we can set transposeMmaOutput via constructor
if (options.mDtypeA == mOptions.eltType && options.mDtypeC == mOptions.outputType
&& options.mUseDeepSeekFp8 == mOptions.deepSeekFp8
if (options.mDtypeA == mOptions.dtypeA && options.mDtypeB == mOptions.dtypeB
&& options.mDtypeC == mOptions.dtypeC && options.mUseDeepSeekFp8 == mOptions.deepSeekFp8
&& options.mTransposeMmaOutput == mOptions.transposeMmaOutput
&& (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct
&& options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch
&& tileSize == mOptions.tileSize)
{
// FIXME: Disable split-k for now.
if (options.mClusterDimZ != 1)
{
continue;
}
if (options.mFusedAct)
{
if (options.mActType != static_cast<batchedGemm::gemmGatedAct::ActType>(mOptions.actType))
{
continue;
}
}
if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM)
{
mPassingConfigIndices.push_back(i);
@ -146,9 +162,10 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC, float const* scaleGateC,
void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device, int32_t configIndex)
float const* ptrBias, float const* ptrAlpha, float const* ptrBeta, float const* ptrClampLimit, void* c,
void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens, int32_t const* ctaIdxXyToBatchIdx,
int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas, void* workspace, CUstream stream, int device,
int32_t configIndex)
{
auto bmm = BatchedGemmInterface();
@ -200,6 +217,10 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
gemmData.mInputBuffers.mPtrScaleGate = scaleGateC;
gemmData.mInputBuffers.mPtrPerTokenSfA = mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA;
gemmData.mInputBuffers.mPtrPerTokenSfB = mOptions.transposeMmaOutput ? perTokensSfA : perTokensSfB;
gemmData.mInputBuffers.mPtrBias = ptrBias;
gemmData.mInputBuffers.mPtrSwiGluAlpha = ptrAlpha;
gemmData.mInputBuffers.mPtrSwiGluBeta = ptrBeta;
gemmData.mInputBuffers.mPtrClampLimit = ptrClampLimit;
gemmData.mInputBuffers.mPtrRouteMap = routeMap;
@ -242,7 +263,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
/* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
/* scaleC */ nullptr, /* scaleGateC */ nullptr, c, outSfC,
/* scaleC */ nullptr, /* scaleGateC */ nullptr, /* ptrBias */ nullptr, /* ptrAlpha */ nullptr,
/* ptrBeta */ nullptr, /* ptrClampLimit */ nullptr, c, outSfC,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
}
void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* sfA, void const* b, void const* sfB, float const* ptrBias, float const* ptrAlpha,
float const* ptrBeta, float const* ptrClampLimit, void* c, void* outSfC, void* workspace, CUstream stream,
int device, int32_t configIndex)
{
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
/* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr,
/* scaleC */ nullptr, /* scaleGateC */ nullptr, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, c, outSfC,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
@ -255,7 +291,9 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
/* sfA */ nullptr, b, /* sfB */ nullptr, /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, scaleC,
scaleGateC, c, /* outSfC */ nullptr,
scaleGateC, /* ptrBias */ nullptr, /* ptrAlpha */ nullptr, /* ptrBeta */ nullptr, /* ptrClampLimit */ nullptr,
c,
/* outSfC */ nullptr,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
@ -281,7 +319,6 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
gemmData.mProblemDimensions.mRank = 0;
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
// Tier 0: K < tileK, prefer higher efficiency.
auto cmpTier0 = [&configs, &gemmData](int64_t idx0, int64_t idx1)
{
@ -343,7 +380,6 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
}
return false;
};
// Sort configs by options.
std::vector<int64_t> sortedIndices = mPassingConfigIndices;
std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier0);

View File

@ -27,10 +27,28 @@ namespace tensorrt_llm
namespace kernels
{
// Keep this in sync with the ActType in
// cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/GemmGatedActOptions.h
enum class ActType
{
// For ActType == SwiGlu, ideally we would like to have something like
// gatedAct = scaleC * (x0 * scaleAb + beta) * ((x1 * scaleGate) * sigmoid(alpha * x1 *
// scaleGate)).
// But for now, we use the simplified version
// gatedAct = scaleC' * (x0 + beta') * ((x1 * scaleGate) * sigmoid(alpha * x1 * scaleGate)),
// where x0 and x1 are the raw numbers from Gemm, while scaleC and scaleGate are input scales,
// beta' = beta / scaleAb, scaleC' = scaleC * scaleAb.
//
// GatedSilu is a special case of SwiGlu where the alpha is 1.0 and the beta is 0.0.
SwiGlu
};
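As a quick illustration of the simplified SwiGlu formula in the comment above, here is a minimal scalar sketch; the function and parameter names are placeholders, not part of the runner API:

#include <cmath>

// Scalar reference for: gatedAct = scaleC' * (x0 + beta') * (g * sigmoid(alpha * g)),
// with g = x1 * scaleGate, beta' = beta / scaleAb and scaleC' = scaleC * scaleAb.
inline float swiGluReference(float x0, float x1, float scaleAb, float scaleGate,
                             float scaleC, float alpha, float beta)
{
    float const betaPrime = beta / scaleAb;
    float const scaleCPrime = scaleC * scaleAb;
    float const g = x1 * scaleGate;
    float const silu = g / (1.0f + std::exp(-alpha * g)); // g * sigmoid(alpha * g)
    return scaleCPrime * (x0 + betaPrime) * silu;
}

With alpha = 1.0 and beta = 0.0 this reduces to the GatedSilu special case mentioned in the comment.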
struct TrtllmGenBatchedGemmRunnerOptions
{
batchedGemm::trtllm::gen::Dtype eltType;
batchedGemm::trtllm::gen::Dtype outputType;
batchedGemm::trtllm::gen::Dtype dtypeA;
batchedGemm::trtllm::gen::Dtype dtypeB;
batchedGemm::trtllm::gen::Dtype dtypeC;
ActType actType{ActType::SwiGlu};
bool deepSeekFp8{false};
bool fusedAct{false};
bool routeAct{false};
@ -53,7 +71,8 @@ public:
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
float const* scaleGateC, float const* bias, float const* swiGluAlpha, float const* swiGluBeta,
float const* clampLimit, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device, int32_t configIndex);
@ -62,6 +81,11 @@ public:
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
int32_t configIndex);
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
void const* b, void const* sfB, float const* bias, float const* swiGluAlpha, float const* swiGluBeta,
float const* clampLimit, void* c, void* outSfC, void* workspace, CUstream stream, int device,
int32_t configIndex);
// FP8 per-tensor scaling GEMM
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,

View File

@ -247,35 +247,10 @@ struct BatchedGemmData
// The clamp limit for the accumulator before applying the activation.
// Shape is [B].
// Clamp is INF if nullptr.
// When the input is FP8 or NVFP4, the clamp has to be scaled by limit' = limit / dequantAb.
// If applied on SwiGlu, it will be:
//
// x_glu = x_glu.clamp(min=None, max=limit)
// x_linear = x_linear.clamp(min=-limit, max=limit)
//
// The given clamp limit applies to the dequantized values, so the order of operations would
// look something like this:
//
// x0 = x0 * dqAb
// x0 = clamp(x0, none, limit)
// x0 = x0 * sigmoid(alpha * x0)
// x1 = dqAb * x1
// x1 = clamp(x1, -limit, limit)
// out = qC * (x1 + beta) * x0
//
// Given that the dqAb and qC are combined into scaleC, we can bring the dqAb into the clamp
// limit and apply the clamping prior to dequantization:
//
// x0 = clamp(x0, none, limit / dqAb)
// x0 = x0 * dqAb
// x0 = x0 * sigmoid(alpha * x0)
// x1 = clamp(x1, -limit / dqAb, limit / dqAb)
// scaleC = dqAb * qC
// beta' = beta / dqAb
// out = scaleC * (x1 + beta') * x0
//
// Note this assumes that scaleAb == scaleGate which is true in TRT-LLM MoE use-case
//
float const* mPtrClampLimit{nullptr};
// The alpha and beta for SwiGlu.
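As a side note, the folding of the dequantization scale into the clamp limit described in the comment block above can be checked with a couple of scalar helpers. This is an illustrative sketch (hypothetical names, gate path only), assuming dqAb > 0:

#include <algorithm>

// Clamp the raw accumulator against limit / dqAb, then dequantize ...
inline float clampThenDequant(float x, float dqAb, float limit)
{
    return std::min(x, limit / dqAb) * dqAb;
}

// ... which matches dequantizing first and clamping against the original limit.
inline float dequantThenClamp(float x, float dqAb, float limit)
{
    return std::min(x * dqAb, limit);
}

Both helpers return the same value for any x when dqAb > 0, which is why the kernel can clamp before dequantization with limit' = limit / dqAb.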

View File

@ -99,7 +99,7 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
int32_t sfReshapeFactor, gemm::TileScheduler tileScheduler, gemmGatedAct::ActType actType, bool clampBeforeAct,
std::vector<int> batchedM, std::vector<int> batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch,
int numTokens, RouteImpl routeImpl, bool gridWaitForPrimaryRouting, bool fusedAct,
int numRegsPerThreadNonEpilogueWarp, int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
int numRegsPerThreadNonEpilogueWarp, int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps)
: gemmGatedAct::GemmGatedActOptions(
gemm::GemmOptions(allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc, dtypeA,
dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs,
@ -116,16 +116,15 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
, mBatchedM(batchedM)
, mBatchedN(batchedN)
, mBatchMode(BatchMode(batchMode))
, mFusedAct(fusedAct)
, mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting)
, mIsStaticBatch(isStaticBatch)
, mNumBatches(numBatches)
, mIsStaticBatch(isStaticBatch)
, mNumTokens(numTokens)
, mRouteImpl(routeImpl)
, mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting)
, mFusedAct(fusedAct)
, mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp)
, mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp)
, mNumRegsCastAWarps(numRegsCastAWarps)
, mNumTokens(numTokens)
, mRouteImpl(routeImpl)
, mUseTmaOobOpt(useTmaOobOpt)
{
}
@ -135,28 +134,28 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
std::vector<int> mBatchedN;
// Whether batching M or N.
BatchMode mBatchMode{BatchMode::BatchM};
// Whether to perform a fused gated activation.
bool mFusedAct{false};
// Number of Gemm batches.
int mNumBatches;
// Whether the batch size is static (i.e. known at kernel launch time).
bool mIsStaticBatch{true};
// Total number of tokens.
int mNumTokens{32};
// Whether load the input tokens and do routing.
RouteImpl mRouteImpl{RouteImpl::NoRoute};
// Whether the loads that load from ptrRouteMap, ptrTotalNumPaddedTokens,
// ptrCtaIdxXyToBatchIdx, etc.. should wait on a grid dependency.
bool mGridWaitForPrimaryRouting{true};
// Whether the batch size is static (i.e. known at kernel launch time).
bool mIsStaticBatch{true};
// Number of Gemm batches.
int mNumBatches;
// Whether to perform a fused gated activation.
bool mFusedAct{false};
// Number of registers per thread for non-epilogue warps
int mNumRegsPerThreadNonEpilogueWarp{0};
// Number of registers per thread for epilogue warps
int mNumRegsPerThreadEpilogueWarp{0};
// Number of registers for the cast A warps.
int mNumRegsCastAWarps{0};
// Total number of tokens.
int mNumTokens{32};
// Whether load the input tokens and do routing.
RouteImpl mRouteImpl{RouteImpl::NoRoute};
// Whether to use TMA out-of-bounds optimization to reduce wasted traffic. See details in
// BatchedGemm/KernelParamsDecl.h.
bool mUseTmaOobOpt{false};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -166,20 +165,6 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
{
bool isValid = true;
if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps)
{
if (updateOptions)
{
// Since any routing (mRouteAct != NoRoute) requires mUseTwoTmaLoadWarps == true.
// Single TMA load warp is not the target use case for OOB optimization.
options.mUseTmaOobOpt = false;
}
else
{
TLLM_CHECK_ERROR(false, "TMA OOB optimization requires two TMA load warps.");
return false;
}
}
if (options.mFusedAct)
{
// ensure that we check the fused options as well
@ -382,8 +367,7 @@ inline std::string dumpOptions(BatchedGemmOptions const& options)
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," << std::endl;
ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," << std::endl;
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << std::endl;
return ss.str();
}

View File

@ -179,7 +179,7 @@ inline std::string dumpOptions(GemmGatedActOptions const& options)
ss << gemm::dumpOptions(options) << ", ";
ss << "mActType="
<< "gemmGatedAct::ActType(" << static_cast<int32_t>(options.mActType) << ")," << std::endl;
ss << "mClampBeforeAct=" << options.mClampBeforeAct << "" << std::endl;
ss << "mClampLimit=" << options.mClampBeforeAct << "," << std::endl;
return ss.str();
}

View File

@ -527,7 +527,6 @@ inline int32_t getShuffleBlockSize(int epilogueTileM)
inline bool checkAndUpdateGemmOptions(
GemmOptions& options, bool isBlackwell, int /* tpGrpSize */, bool updateOptions = true)
{
if (options.mDtypeB == tg::Dtype::Void)
{
if (updateOptions)
@ -568,8 +567,7 @@ inline bool checkAndUpdateGemmOptions(
// Currently, we only support {MxFp4, NvFp4} -> Bf16.
TLLM_CHECK_ERROR((options.mDtypeA == options.mDtypeMmaA)
|| ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1)
&& options.mDtypeMmaA == tg::Dtype::Bfloat16)
|| (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3),
&& options.mDtypeMmaA == tg::Dtype::Bfloat16),
"Unsupported cast for A: ", tg::dtypeToString(options.mDtypeA), " -> ", tg::dtypeToString(options.mDtypeMmaA));
// Check that the B cast is supported.

View File

@ -18,7 +18,6 @@
#include "trtllm/gen/CommonUtils.h"
#include "trtllm/gen/SfLayoutDecl.h"
#include <stdexcept>
#include "BatchedGemmEnums.h"
#include "Enums.h"
@ -52,7 +51,11 @@ namespace tg = trtllm::gen;
namespace KernelParamsSetup
{
#ifdef TLLM_ENABLE_CUDA
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// Member functions.
//
//////////////////////////////////////////////////////////////////////////////////////////////////
enum class MatrixType
{
MatrixA = 0,
@ -60,38 +63,6 @@ enum class MatrixType
MatrixC
};
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// Utility functions.
//
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename BatchedGemmOptions>
bool useTmaOobOptA(BatchedGemmOptions const& options)
{
return options.mBatchMode == BatchedGemmOptions::BatchMode::BatchM && doesRouteImplUseNoRoute(options.mRouteImpl)
&& options.mUseTmaOobOpt;
}
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename BatchedGemmOptions>
bool useTmaOobOptB(BatchedGemmOptions const& options)
{
return options.mBatchMode == BatchedGemmOptions::BatchMode::BatchN && doesRouteImplUseNoRoute(options.mRouteImpl)
&& options.mUseTmaOobOpt;
}
//////////////////////////////////////////////////////////////////////////////////////////////////
template <typename BatchedGemmOptions>
bool useTmaOobOptC(BatchedGemmOptions const& options)
{
return options.mUseTmaStore && options.mUseTmaOobOpt;
}
//////////////////////////////////////////////////////////////////////////////////////////////////
// Create the TMA shape/stride for A/B/C.
template <class GemmOptions>
static auto makeTmaShapeStrideAbc(
@ -102,83 +73,60 @@ static auto makeTmaShapeStrideAbc(
bool const isWeights = (matrixType == MatrixType::MatrixA && options.mTransposeMmaOutput)
|| (matrixType == MatrixType::MatrixB && !options.mTransposeMmaOutput);
// Whether to use TMA OOB trick to block out padded dummy tokens and saving BW whenever no routing
// is involved. It applies to batchM and matrixA, or batchN and matrixB, or any case for matrixC.
bool const useTmaOobOpt = matrixType == MatrixType::MatrixA ? useTmaOobOptA(options)
: matrixType == MatrixType::MatrixB ? useTmaOobOptB(options)
: matrixType == MatrixType::MatrixC ? useTmaOobOptC(options)
: false;
// The outer dimension.
auto numTokens = (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? mM : mN;
// The outer dimension tile size.
auto ctaTileNumTokens = (matrixType == MatrixType::MatrixA || matrixType == MatrixType::MatrixC) ? tileM : tileN;
// The outer dimension of TMA box shape.
auto tileNumTokens = (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileM : ctaTileNumTokens;
auto tileNumTokens = (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileM
: (matrixType == MatrixType::MatrixA) ? tileM
: tileN;
// The inner dimension.
auto hiddenSize = (matrixType == MatrixType::MatrixC) ? mN : mK;
// The inner dimension tile size.
auto ctaTileHiddenSize = (matrixType == MatrixType::MatrixC) ? tileN : tileK;
// The inner dimension of TMA box shape.
auto tileHiddenSize = (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileN : ctaTileHiddenSize;
auto tileHiddenSize = (matrixType == MatrixType::MatrixC) ? options.mEpilogueTileN : tileK;
// Swap matrix C sizes if output is transposed.
// Swap matrix C sizes if output is transpose
if (matrixType == MatrixType::MatrixC && options.mTransposeMmaOutput)
{
std::swap(numTokens, hiddenSize);
std::swap(ctaTileNumTokens, ctaTileHiddenSize);
std::swap(tileNumTokens, tileHiddenSize);
numTokens = mN;
hiddenSize = mM;
tileNumTokens = options.mEpilogueTileN;
tileHiddenSize = options.mEpilogueTileM;
}
// For a fused activation kernel, the hidden size of output is halved. TODO: That's true for
// gated activations but not regular activations.
if (options.mFusedAct && matrixType == MatrixType::MatrixC)
if (options.mFusedAct)
{
hiddenSize /= 2;
tileHiddenSize /= 2;
ctaTileHiddenSize /= 2;
if (matrixType == MatrixType::MatrixC)
{
hiddenSize /= 2;
tileHiddenSize /= 2;
}
}
// The cute tensor shape for A/B: (numTokens, hiddenSize).
// Note that TMA descriptor expects the first dimension's stride to be
// 1, so swap the first two dimension so that the hiddenSize dimension comes first.
// Activations matrix is 2D (sum(divUpMul(M[bi], tileM) for bi in B), K).
std::vector<uint64_t> shape = {static_cast<uint64_t>(hiddenSize), static_cast<uint64_t>(numTokens)};
if (useTmaOobOpt /* also implies input/output activation */)
auto shape = std::vector<uint64_t>{static_cast<uint64_t>(hiddenSize), static_cast<uint64_t>(numTokens)};
// If the matrix is a weights matrix, we use 3D logical shape for it (B, M, K) or (B, N, K).
// Activations matrix is 2D (sum(divUpMul(M[bi], tileM) for bi in B), K).
if (isWeights)
{
// If TMA OOB optimization is used, we use 3D logical shape (M, tileM, K) or (N, tileN, K).
// The outer dimension is extended to make room for the possible counterbalance positive
// offset from the middle "bound" dimension. The counterbalance should be no more than
// ctaTileNumTokens.
shape = {static_cast<uint64_t>(hiddenSize), static_cast<uint64_t>(ctaTileNumTokens),
static_cast<uint64_t>(numTokens + ctaTileNumTokens)};
}
else if (isWeights)
{
// If the matrix is a weights matrix, we use 3D logical shape (B, M, K) or (B, N, K).
shape = {static_cast<uint64_t>(hiddenSize), static_cast<uint64_t>(numTokens),
static_cast<uint64_t>(options.mNumBatches)};
shape.push_back(static_cast<uint64_t>(options.mNumBatches));
}
// Assemble the stride (strideTokens, 1).
// Swap the first two dimension as mentioned before.
std::vector<uint64_t> stride = {1, static_cast<uint64_t>(hiddenSize)};
if (useTmaOobOpt)
auto stride = std::vector<uint64_t>{1, static_cast<uint64_t>(hiddenSize)};
if (isWeights)
{
stride = {1, static_cast<uint64_t>(hiddenSize), static_cast<uint64_t>(hiddenSize)};
}
else if (isWeights)
{
stride = {
1, static_cast<uint64_t>(hiddenSize), static_cast<uint64_t>(hiddenSize) * static_cast<uint64_t>(numTokens)};
stride.push_back(static_cast<uint64_t>(hiddenSize * numTokens));
}
// Assemble the box shape
std::vector<int32_t> tileShape = {tileHiddenSize, tileNumTokens};
// Alternate layouts (MajorMn and BlockMajorK) do not apply to matrixC
// Alternate layouts do not apply to matrixC
if (matrixType != MatrixType::MatrixC)
{
gemm::MatrixLayout layout = (matrixType == MatrixType::MatrixA) ? options.mLayoutA : options.mLayoutB;
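To summarize the A/B shape/stride construction restored in this hunk, here is a standalone sketch. The helper name is hypothetical and it assumes row-major storage with unit stride along the hidden dimension:

#include <cstdint>
#include <vector>

// Activations (A/B) use a 2D logical shape; weights add a batch dimension.
// TMA expects the unit-stride dimension first, so hiddenSize comes before numTokens.
inline void makeAbShapeStrideSketch(uint64_t numTokens, uint64_t hiddenSize, uint64_t numBatches,
    bool isWeights, std::vector<uint64_t>& shape, std::vector<uint64_t>& stride)
{
    shape = {hiddenSize, numTokens};   // 2D: (sum(divUpMul(M[bi], tileM) for bi in B), K)
    stride = {1, hiddenSize};
    if (isWeights)
    {
        shape.push_back(numBatches);                  // 3D: (B, M, K) or (B, N, K)
        stride.push_back(hiddenSize * numTokens);
    }
}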
@ -348,8 +296,8 @@ static KernelParams setKernelParams(GemmOptions_ const& options, bool const batc
for (int b = 0; b < options.mNumBatches; b++)
{
int mM = batchM ? options.mBatchedM[b] : options.mM;
int mN = batchM ? options.mN : options.mBatchedN[b];
int mM = batchM ? options.mBatchedM[b] : options.mN;
int mN = batchM ? options.mM : options.mBatchedN[b];
// Skip Tma descriptor creation if expert isn't used
if (mM == 0 || mN == 0)

View File

@ -1,3 +1,4 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
@ -18,7 +19,6 @@
namespace batchedGemm
{
// This is device code
struct KernelParams
@ -32,55 +32,6 @@ struct KernelParams
// Maximum number of CTAs
static constexpr int MaxNumCtas = 2048;
// NOTE: TMA out-of-bounds optimization for MoE padded tokens:
//
// Originally the padded tokens is a 2D tensor [hiddenDim, ctaGridDimY * tileN] with stride [1,
// hiddenDim] and box size [tileM, tileN] at pointer p. We waste bandwidth bytes since we only
// want to load [0, batchEnd) out of the [0, tileN) box size: batchEnd is a runtime variable while
// box size needs to be fixed at compile time.
//
// To deal with this, we reshape the tensor to 3D: [hiddenDim, tileN, ctaGridDimY * tileN] with
// stride [1, hiddenDim, hiddenDim] and box size [tileM, tileN, 1]. For the original 2D
// tensor,
//
// Offset Coords [ : , ctaIdxY * tileN ],
// Box Sizes [ : , tileN ],
// Coords Range [ : , ctaIdxY * tileN : ctaIdxY * tileN + tileN],
//
// while we only want load the range [ctaIdxY * tileN, ctaIdxY * tileN + batchEnd), 1 <= batchEnd
// <= tileN
//
// For the reshaped 3D tensor,
//
// Offset Coords [ : , tileN - batchEnd ,
// ctaIdxY * tileN + batchEnd ],
// Box Sizes [ : , tileN ,
// 1 ],
// Coords Range [ : , tileN - batchEnd : min(tileN, 2 * tileN - batchEnd),
// ctaIdxY * tileN + batchEnd : ctaIdx * tileN + batchEnd + 1],
//
// while min(tileN, 2 * tileN - batchEnd) always evaluates to tileN. The unwanted tokens are
// essentially filtered out by utilizing the OOB feature of TMA. Since the 2nd and 3rd dimension
// has the same stride, we end up loading the following (adding the left and right end of the 2nd
// and 3rd dimension ranges):
//
// Effective 2D Coords Range
// [ : , tileN + ctaIdxY * tileN : tileN + ctaIdxY * tileN + batchEnd],
//
// This is exactly the same as the original range except for the offset tileN, thus we also need
// to offset the pointer in the opposite direction:
//
// Ptr (p) -> Ptr (p - tileN * hiddenDim)
//
// Due to the restrictions of TMA unit, the above operations requires the TMA descriptor and the
// underlying buffer be constructed differently:
// - Requires valid buffer at (p - tileN * hidden) - needs prepending `tileN` tokens.
// - TMA outermost dimension must be extended by `tileN` or loads will OOB in the rightmost side.
// The latter is because when batchEnd == tileN, the offset coords in the 3rd dimension becomes
// ctaIdxY * tileN + tileN. When ctaIdxY = ctaGridDimY - 1, it becomes ((ctaGridDimY - 1) * tileN
// + tileN = ctaGridDimY * tileN which is equal to the 3rd dimension size and will be filtered
// out. That's why we need to extend the tensor size by tileN.
//
// TMA descriptor for A.
// Must be setup using gemm::buildNdTmaDescriptor with shapes and strides from
// makeTmaShapeStrideAbc.
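For reference, the index arithmetic in the out-of-bounds note removed above can be sketched in a few lines; this is an illustrative function, not part of KernelParams:

#include <cstdint>
#include <utility>

// For one CTA along the token dimension, the removed comment uses a 3D box of size
// {tileM, tileN, 1} with offsets {0, tileN - batchEnd, ctaIdxY * tileN + batchEnd}.
// Because the 2nd and 3rd dimensions share the stride hiddenDim, this is equivalent to the
// 2D range [tileN + ctaIdxY * tileN, tileN + ctaIdxY * tileN + batchEnd), which the shifted
// base pointer (p - tileN * hiddenDim) maps back onto the intended tokens.
inline std::pair<int32_t, int32_t> tmaOobOffsetsSketch(int32_t ctaIdxY, int32_t tileN, int32_t batchEnd)
{
    int32_t const boxOffset = tileN - batchEnd;              // offset inside the [0, tileN) box dim
    int32_t const outerOffset = ctaIdxY * tileN + batchEnd;  // offset along the extended outer dim
    return {boxOffset, outerOffset};
}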

View File

@ -20,7 +20,6 @@
#include "trtllm/gen/CommonUtils.h"
#include "trtllm/gen/DtypeDecl.h"
#include <cassert>
#include <stdexcept>
namespace batchedGemm
{
@ -78,38 +77,6 @@ public:
}
// Returns the offset of the ith chunk
int32_t getChunkOffsetByName(std::string const& name) const
{
for (size_t ii = 0; ii < mSmemChunkNames.size(); ++ii)
{
if (mSmemChunkNames[ii] == name)
{
return getChunkOffset(ii);
}
}
throw std::runtime_error("Name not found: " + name);
}
// Returns the first chunk reuse flag given chunk name.
int getFirstChunkReuseFlagByName(std::string const& name) const
{
for (size_t ii = 0; ii < mSmemChunkNames.size(); ++ii)
{
if (mSmemChunkNames[ii] == name)
{
return getFirstChunkReuseFlag(ii);
}
}
throw std::runtime_error("Name not found: " + name);
}
// Function to calculate the total size of the SMEM array
int32_t getTotalSize() const
{
return getOffsetBeforeChunk(static_cast<int32_t>(mNumBytesAndAlignmentPerSmemChunk.size()));
}
private:
int32_t getChunkOffset(int32_t ii) const
{
if (mFirstChunkReuse[ii])
@ -124,6 +91,12 @@ private:
return getSizePaddedToAlignment(offset, mNumBytesAndAlignmentPerSmemChunk[ii].second);
}
// Function to calculate the total size of the SMEM array
int32_t getTotalSize() const
{
return getOffsetBeforeChunk(static_cast<int32_t>(mNumBytesAndAlignmentPerSmemChunk.size()));
}
// Returns the first chunk reuse flag for the ith chunk.
int getFirstChunkReuseFlag(int32_t ii) const
{
@ -166,7 +139,9 @@ int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind)
{
if (mmaKind == tg::MmaKind::Auto)
{
throw std::runtime_error("mmaKind != tg::MmaKind::Auto");
std::cout << "mmaKind != tg::MmaKind::Auto" << std::endl;
assert(false);
return -1;
}
if (mmaKind == tg::MmaKind::MxFp8Fp6Fp4)
{
@ -566,14 +541,14 @@ inline int32_t getTmemBufferSize(KernelTraits traits)
inline int32_t getSmemOffsetLoadA(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemLoadA");
return traits.mSmemAllocatorHelper.getChunkOffset(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetLoadB(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemLoadB");
return traits.mSmemAllocatorHelper.getChunkOffset(1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -587,63 +562,64 @@ inline int32_t getSmemOffsetLoadAb(KernelTraits traits)
inline int32_t getSmemOffsetLoadShuffleB(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemBShuffle");
return traits.mSmemAllocatorHelper.getChunkOffset(2);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetGmemC(KernelTraits traits, int resIdx = 0)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemGmemC" + std::to_string(resIdx));
return traits.mSmemAllocatorHelper.getChunkOffset(3 + resIdx);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetRowMax(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemRowMax");
return traits.mSmemAllocatorHelper.getChunkOffset(5);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetSliceK(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemSliceK");
return traits.mSmemAllocatorHelper.getChunkOffset(6);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetPerTokenSf(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemPerTokenSf");
return traits.mSmemAllocatorHelper.getChunkOffset(7);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetBias(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemBias");
return traits.mSmemAllocatorHelper.getChunkOffset(8);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetBlockAmax(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemBlockAmax");
return traits.mSmemAllocatorHelper.getChunkOffset(9);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetConstSfBuf(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemConstSfBuf");
return traits.mSmemAllocatorHelper.getChunkOffset(10);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t isSmemAbRepurposedToGmemC(KernelTraits traits, int resIdx = 0)
{
return traits.mSmemAllocatorHelper.getFirstChunkReuseFlagByName("smemGmemC" + std::to_string(resIdx));
// Be conscious that the index (3 + resIdx) should match the index in getSmemOffsetGmemC().
return traits.mSmemAllocatorHelper.getFirstChunkReuseFlag(3 + resIdx);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -654,28 +630,28 @@ inline int32_t isSmemAbRepurposedToGmemC(KernelTraits traits, int resIdx = 0)
inline int32_t getTmemOffsetD(KernelTraits traits)
{
return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemD");
return traits.mTmemAllocatorHelper.getChunkOffset(0);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getTmemOffsetA(KernelTraits traits)
{
return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemA");
return traits.mTmemAllocatorHelper.getChunkOffset(1);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getTmemOffsetSfA(KernelTraits traits)
{
return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemSfA");
return traits.mTmemAllocatorHelper.getChunkOffset(2);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getTmemOffsetSfB(KernelTraits traits)
{
return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemSfB");
return traits.mTmemAllocatorHelper.getChunkOffset(3);
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -181,8 +181,6 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, st
if (result != CUDA_SUCCESS)
{
char const* errorString;
cuGetErrorString(result, &errorString);
std::stringstream ss;
ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl;
@ -285,10 +283,8 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector<uint64_t> c
if (result != CUDA_SUCCESS)
{
char const* errorString;
cuGetErrorString(result, &errorString);
std::stringstream ss;
ss << "Error: Failed to initialize the TMA descriptor for SF " << errorString << std::endl;
ss << "Error: Failed to initialize the TMA descriptor for SF " << result << std::endl;
ss << "tmaFormat: " << static_cast<int>(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl;

View File

@ -12,6 +12,7 @@
"epilogueTileM": 128,
"epilogueTileN": 8,
"numStages": 4,
"numStagesMma": 1,
"numSlicesForSplitK": 1,
"useTwoTmaLoadWarps": true,
"clusterDimX": 1,
@ -30,7 +31,8 @@
"sfLayoutC": "8x4",
"batch": "N",
"numExperts": 128,
"useCudaGraph": true
"useCudaGraph": true,
"clampLimit": 2
},
"BatchedGemmPerTensorScalingFp8LowLatency": {
"dtypeA": "e4m3",
@ -44,6 +46,7 @@
"epilogueTileM": 128,
"epilogueTileN": 8,
"numStages": 3,
"numStagesMma": 1,
"numSlicesForSplitK": 1,
"useTwoTmaLoadWarps": true,
"clusterDimX": 1,
@ -60,7 +63,8 @@
"gridWaitForPrimaryB": true,
"batch": "N",
"numExperts": 128,
"useCudaGraph": true
"useCudaGraph": true,
"clampLimit": 2
},
"BatchedGemmDeepSeekFp8LowLatency": {
"dtypeA": "e4m3",
@ -94,7 +98,123 @@
"numStagesMma": 4,
"batch": "N",
"numExperts": 128,
"useCudaGraph": true
"useCudaGraph": true,
"clampLimit": 2
},
"BatchedGemmMxE2m1E4m3LowLatency": {
"dtypeA": "mxe2m1",
"dtypeB": "e4m3",
"dtypeC": "e4m3",
"dtypeMmaB": "mxe4m3",
"mmaM": 128,
"mmaN": 8,
"mmaK": 32,
"tileM": 128,
"tileN": 8,
"tileK": 512,
"epilogueTileM": 128,
"epilogueTileN": 8,
"numStages": 3,
"numStagesMma": 1,
"numSlicesForSplitK": 1,
"useTwoTmaLoadWarps": true,
"clusterDimX": 1,
"clusterDimY": 1,
"clusterDimZ": 1,
"sliceK": false,
"transposeMmaOutput": true,
"useShuffledMatrixA": true,
"useDeepSeekFp8": false,
"useTmaStore": true,
"useCustomMmaSchedule": true,
"gridTriggerSecondaryB": true,
"gridWaitForPrimaryA": false,
"gridWaitForPrimaryB": true,
"sfLayoutB": "8x4",
"sfLayoutC": "8x4",
"batch": "N",
"numExperts": 128,
"useCudaGraph": true,
"biasType": "m",
"act": "swiglu",
"clampLimit": 2
},
"BatchedGemmMxE2m1MxE4m3LowLatency": {
"dtypeA": "mxe2m1",
"dtypeB": "mxe4m3",
"dtypeC": "mxe4m3",
"mmaM": 128,
"mmaN": 8,
"mmaK": 32,
"tileM": 128,
"tileN": 8,
"tileK": 512,
"epilogueTileM": 128,
"epilogueTileN": 8,
"numStages": 3,
"numStagesMma": 1,
"numSlicesForSplitK": 1,
"useTwoTmaLoadWarps": true,
"clusterDimX": 1,
"clusterDimY": 1,
"clusterDimZ": 1,
"sliceK": false,
"transposeMmaOutput": true,
"useShuffledMatrixA": true,
"useDeepSeekFp8": false,
"useTmaStore": true,
"useCustomMmaSchedule": true,
"gridTriggerSecondaryB": true,
"gridWaitForPrimaryA": false,
"gridWaitForPrimaryB": true,
"sfLayoutB": "8x4",
"sfLayoutC": "8x4",
"batch": "N",
"numExperts": 128,
"useCudaGraph": true,
"biasType": "m",
"act": "swiglu",
"clampLimit": 2
},
"BatchedGemmMxE2m1Bf16LowLatency": {
"dtypeA": "mxe2m1",
"dtypeB": "bf16",
"dtypeC": "bf16",
"dtypeMmaA": "bf16",
"dtypeMmaB": "bf16",
"mmaM": 128,
"mmaN": 8,
"mmaK": 16,
"tileM": 128,
"tileN": 8,
"tileK": 256,
"epilogueTileM": 128,
"epilogueTileN": 8,
"numStages": 3,
"numStagesMma": 1,
"numSlicesForSplitK": 1,
"useTwoTmaLoadWarps": true,
"clusterDimX": 1,
"clusterDimY": 1,
"clusterDimZ": 1,
"sliceK": false,
"transposeMmaOutput": true,
"useShuffledMatrixA": true,
"useDeepSeekFp8": false,
"useTmaStore": true,
"useCustomMmaSchedule": true,
"gridTriggerSecondaryB": true,
"gridWaitForPrimaryA": false,
"gridWaitForPrimaryB": true,
"sfLayoutB": "8x4",
"sfLayoutC": "8x4",
"batch": "N",
"numExperts": 128,
"useCudaGraph": true,
"biasType": "m",
"act": "swiglu",
"patchF2fp": true,
"clampLimit": 2
}
},
"configs": [
@ -221,6 +341,7 @@
"_template": "BatchedGemmPerTensorScalingFp8LowLatency",
"routeAct": true,
"fusedAct": true,
"usePerTokenSfB": true,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "e4m3",
"numTokens": 2,
@ -243,6 +364,150 @@
["static", 1],
["persistent", 2]
]
},
{
"_comment": "MxFp4xFp8_FC1",
"_template": "BatchedGemmMxE2m1E4m3LowLatency",
"routeAct": "ldgsts",
"fusedAct": true,
"sfLayoutB": "linear",
"useUnrollLoop2xForMma": [true, false],
"numTokens": 2,
"numExperts": 2,
"mmaN,tileN,epilogueTileN,tileK,numStages": [
[8, 8, 8, 512, 3],
[8, 8, 8, 256, 5],
[16, 16, 16, 256, 5],
[32, 32, 32, 256, 5],
[64, 64, 64, 256, 4]
],
"tileScheduler,numStagesMma": [
["static", 1],
["persistent", 2]
]
},
{
"_comment": "MxFp4xFp8_FC2",
"_template": "BatchedGemmMxE2m1E4m3LowLatency",
"routeAct": false,
"fusedAct": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "bf16",
"numTokens": 2,
"numExperts": 2,
"mmaN,tileN,epilogueTileN,tileK,numStages": [
[8, 8, 8, 512, 3],
[8, 8, 8, 256, 5],
[16, 16, 16, 256, 5],
[32, 32, 32, 256, 5],
[64, 64, 64, 256, 4]
],
"tileScheduler,numStagesMma": [
["static", 1],
["persistent", 2]
]
},
{
"_comment": "MxFp4xMxFp8_FC1",
"_template": "BatchedGemmMxE2m1MxE4m3LowLatency",
"routeAct": "ldgsts",
"fusedAct": true,
"sfLayoutB": "linear",
"useUnrollLoop2xForMma": [true, false],
"numTokens": 2,
"numExperts": 2,
"mmaN,tileN,epilogueTileN,tileK,numSlicesForSplitK,clusterDimZ,numStages": [
[8, 8, 8, 512, 1, 1, 3],
[8, 8, 8, 512, 2, 2, 3],
[8, 8, 8, 256, 1, 1, 4],
[8, 8, 8, 256, 1, 1, 5],
[8, 8, 8, 256, 1, 1, 6],
[8, 8, 8, 256, 2, 2, 4],
[8, 8, 8, 256, 2, 2, 5],
[8, 8, 8, 256, 2, 2, 6],
[16, 16, 16, 256, 1, 1, 3],
[16, 16, 16, 256, 1, 1, 4],
[32, 32, 32, 256, 1, 1, 3],
[32, 32, 32, 256, 1, 1, 4],
[64, 64, 64, 256, 1, 1, 3],
[64, 64, 64, 256, 1, 1, 4]
],
"tileScheduler,numStagesMma": [
["static", 1],
["persistent", 2]
]
},
{
"_comment": "MxFp4xMxFp8_FC2",
"_template": "BatchedGemmMxE2m1MxE4m3LowLatency",
"routeAct": false,
"fusedAct": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "bf16",
"numTokens": 2,
"numExperts": 2,
"mmaN,tileN,epilogueTileN,tileK,numSlicesForSplitK,clusterDimZ,numStages": [
[8, 8, 8, 512, 1, 1, 3],
[8, 8, 8, 512, 2, 2, 3],
[8, 8, 8, 256, 1, 1, 4],
[8, 8, 8, 256, 1, 1, 5],
[8, 8, 8, 256, 1, 1, 6],
[8, 8, 8, 256, 2, 2, 4],
[8, 8, 8, 256, 2, 2, 5],
[8, 8, 8, 256, 2, 2, 6],
[16, 16, 16, 256, 1, 1, 3],
[16, 16, 16, 256, 1, 1, 4],
[32, 32, 32, 256, 1, 1, 3],
[32, 32, 32, 256, 1, 1, 4],
[64, 64, 64, 256, 1, 1, 3],
[64, 64, 64, 256, 1, 1, 4]
],
"tileScheduler,numStagesMma": [
["static", 1],
["persistent", 2]
]
},
{
"_comment": "MxFp4xBf16_FC1",
"_template": "BatchedGemmMxE2m1Bf16LowLatency",
"routeAct": "ldgsts",
"fusedAct": true,
"sfLayoutB": "linear",
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "bf16",
"numTokens": 2,
"numExperts": 2,
"tileK": 256,
"mmaN,tileN,epilogueTileN,numStages": [
[8, 8, 8, 3],
[16, 16, 16, 3],
[32, 32, 32, 3],
[64, 64, 64, 3]
],
"tileScheduler,numStagesMma": [
["static", 1],
["persistent", 2]
]
},
{
"_comment": "MxFp4xBf16_FC2",
"_template": "BatchedGemmMxE2m1Bf16LowLatency",
"routeAct": false,
"fusedAct": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "bf16",
"numTokens": 2,
"numExperts": 2,
"mmaN,tileN,epilogueTileN,numStages": [
[8, 8, 8, 3],
[16, 16, 16, 3],
[32, 32, 32, 3],
[64, 64, 64, 3]
],
"tileScheduler,numStagesMma": [
["static", 1],
["persistent", 2]
]
}
]
}

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eba4edca5eaa1fc6b654c4b720339cd536a02723cc798fc17cb31314a1681633
size 701588

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:20ad15a9f6be1c021baf23f4f24154c22a05ce90d26c631dc21bf53e0a489174
size 581311

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:815661113030c8580cc963c443afc3c82fff7c8f8dd8a0ed98f95a08a91f619a
size 684616

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75593264df3d37f7d23664c02e426196a5f7ee1cc6de76db84895fca4e706c97
size 562811

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1de230b7b4bcfec7b5100f8ddbe05f3789577a38767c314cc22554a0b4463275
size 722952

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4bcb2a1f2c500a4755940bc5803c9e7ac0a3987d671ff705599849625339dd0a
size 603467

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11afdd6753faeffe1446abed54c61ccf0e190ff886fc0270d055a4dbce9a9298
size 705390

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e4ecf7abe90d1a7f46a1ff3005cad8ebe0de8e7cf89862cf993f228a98ea00d
size 582747

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58a2bd3f71c60360ac8aabb9a70e96f69e7ce9cb8de89c36fc10786cc47f5eb7
size 684024

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b5cc810d34a7ad8dfdf60f3708158afd3a1743528d6788dca91868bc86c66845
size 567153

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e921f66e6adc5292b99df5cfbae4a9cbae6182c8e99bbc83ea30bd1ca8ed8f55
size 667892

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3b6994776b67405b359d471fa5c873efa4714dd71397ddfd82a52d59cbf20a9a
size 550035

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0252fbcd6f38bcd3e39e07cdabcf7776a4089745290caeac27b6d446ba1cb46e
size 717132

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e5b588e415053600a90d38211bcb5a969471b0e5e6c3437d78254d236129a270
size 593303

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec5cf07ec8b7a4405c305fb82f9eb7179a4a43ab14a2eacfadc35072b317cfd7
size 700704

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6188f80d9ca7f95ea404a72de95bb8a38cace5dd8a8e76527fd83cf16aaff87d
size 575543

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4be5d8b93abcd165428c9770e68286224d01f6624941ed534fb66bcc17344ab
size 743964

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0482d8975738bd4e614dbb86745c5d49445d6a5c460d2245d3f17ce7bc992a1a
size 585703

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b725a2ec74293ef928ade1206ebf2e3726f5980bc943f157372a414834d756fd
size 725020

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9b2bd2855cc99dd074f24435265bfc32d0114a2d9e02ff98565c7881095674dc
size 566363

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd31bc4f601149aba47f5e78fe197db66a83afde56071e8281818e1de727daa9
size 765280

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3a73ddc42c26585294145e543dedcf4140ff538044e5113aefed6dd5537cdaa2
size 607019

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2be435473939c1f81601b61b372a37982b0aa0f107cd4778c201d160c4f8e43c
size 745004

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba80fab864884c31fdca7963a86df70f384e06e5ef76e904971543af60a05c06
size 586347

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:01cf2fbe35943350cbbc9930f632af76b9f2fc1de61dda1f55bd97c9cb15792e
size 687034

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f7e50ffcd72e58781a11e39f3d3de530214aaa25b9d8b78bc5326f5a17bec622
size 570163

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cdf20f912daa7f6e4580ead6b14f66c3aa0d70d536dfdb509ab06574f8dedcc2
size 670656

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04085e7681246fbf85f61871704ceb68ec39dc3a2ed9a4c3b9855b8da6d6a0f6
size 552255

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fb2d60b386a7d82f5076a78095ac9ca7a1431c082c0f11915ef602b72e16d2b8
size 719944

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d556fada7372596fd09b3e8a06f21c67b891f35890bef733bd278d4f746e578e
size 596263

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:46d47463f5c2c2711468fde51b407b9e31fc7458cd98c2f2d79181037554591e
size 702972

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:58fb73ed28d0c7c1e0705fcade19faa88e7b310c4b97d0004c84ed52cc275cd1
size 577763

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4356863633b801a23e960dc4e9e424a69c7b1897a7bc97579e23504235d3e373
size 718360

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:40476bd068ca9bbc4f7217fa9fc04b6b7fdecdb45eb1e9f62dcd566e55c7d07a
size 596605

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fed2ac3d4effc40881584f93f4ecf938d0081131cb7b5fa26519be7101d8b0a4
size 700600

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:516bdac65a439b1ab5615912d573172c38792ab898e68f1b972253588767c398
size 577315

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f89b00b57cafb59f1b912d03ffb8037fba3cc614274d600a758f1597a4d0ab19
size 737950

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:16532048133f5f840f3823ecc343d3d5be65164341ae4f2900dba6901f50c588
size 618760

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:31b8d4b7836e084a9b3f5e8a76d8257f33840162d26b385db5bfcccfd36333fa
size 719648

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10607e6a0e33ca77881111bf04f7e190bab576f1987ecaf38298b96682a3f51d
size 598089

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da385287279c9383155405cd2344705c4607b7742466e8a545fef694bb75cb49
size 696556

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3273045650b7230c83758d9ccfd6cc06b52cbdd91624e50538481d447e48c509
size 566955

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28e437809e9158ec82047a8ff72248fe641c735eb6ebb50984b76fda16df32b2
size 680966

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ea3a2e628b52e2ab71dee5405da588b8735f5cd2d250dd2b106621c603bc4183
size 549047

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e6f459a3ac169228f685c60db132055f25ff9ed8c3a8455be315858eb97d223a
size 729564

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24e5cb1112172dfa0522f99de2d1cef6bfa50304b2d3c32fbde39bca36e92567
size 593847

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7cebb410d52fda0fef353941e34f1e9b09e152b6d801885d43f39b05ab7feecb
size 713284

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a01cadc451bde20a1dc1b2905f216468009401f5c230d8e3e172eb7c0e19a73e
size 574557

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a6d6ac9bea4823aa5e7f7efa696533e6d52405e5cfeba7732051e30789813eed
size 697586

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1d401d5ecb5ea7fe0a689c516a195e4eb325a1b2c6d38e2f0c49f712c923efce
size 576817

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8025a48697e5c22bb2cd0d21023d7a554ef64cb8658f24f171c6c249b0104941
size 680714

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f2c747c157bc41de3c7e85f31cb39752451d5809476af277602d64c0b5a6cb27
size 558317

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f945903cacb99529e668335deaeb0cc9a72f6261e0c2f1f87e037c31d0dbf29a
size 719000

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d9ca86c5b4528521ece2e2694187d67336e4b92a315d8fa3db9ed3fb9e4ea3aa
size 599761

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:22751baf77b6f7d97d3914458973a6ca03dfe01e7f9fd84e419903b133d82b16
size 700746

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3aa5c83701db0e0adf0f1727f454867a29c941587f1d6be2fa496143b5e768a1
size 579089

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b60a53ca988b19c6881800d3a8c03e3464e02644c9c8552bb8c9155defb5e3e4
size 679480

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:000feee3db5eba2d629e2c31599fa66807943f60e3205cd69c5c30e418a8a7ee
size 563497

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4531ad2f99ffcb3326fe8c3312ac93f47dc3f9cc564ee8bd8cabc80766d79828
size 560519

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:70f3e4c712f6c1e7a887771473daea1f4ebdb3febb4ae8e0a8e8045ee867712b
size 663694

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b6d4e15ebd2d638df47afc4fcc58ff56ba8032d7352475cdae04c20195795557
size 546377

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9a4899f491141f32b13e27af6af56d729abaf5e0454cbd6f1a6a2f409ccb237d
size 542611

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2bf8748bf1007349c5fc8cf062ba90669680ec8fe3681ecb6174cb1da0548a7f
size 713030

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b0ad13b702274ebacb48c40948662d38f6c184b61b868f3c6ed062862c712b04
size 589597

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4310d5b81d090acdfa8481681b3310a6417b0b72e67ca6b7605a5766bd691296
size 586619

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d8185e69e406c9e2cb2533f6c4f17ceba7c938642439634f265acad34cc56af5
size 696010

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7f65dfb888be44aa2213e8a8a9dc4ff984e4224135df2e94241bb52ce60c19df
size 571097

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:55285b3eaf9712c22ebc203ec690148dfb971f477710c0fc8139070573f138d8
size 567331

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6ebe7bcc1ee67e5c47ccc7c3fbf3f9fb11c9ee3df907cff85bcf9efdb06e473c
size 626970

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e4c7081a9e7f297a9406783919d035cad4e542e1ffdf261957a2a1c5c1807ae
size 505905

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:85420fddbd50cf79cc03a9f7b42957e853967e7995c768b874c38c23c317cd93
size 621100

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:55ef9504c71a596c687518e99a9de172baa0d407acaa438f810cc66b4ab03353
size 497567

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ef80a87cb0f3df7e0663a811eeb2acb5471d886e9b7c57e1d1700f39fefafdd1
size 654748

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e4af837c7a8837637f32715f5ec16bdf37d88edc5722d5e73464a215d3815512
size 524459

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d7d29557f9f4d2c5b259f81c0f26f9df5edc25f4029cbb0b4dbecf3e90e71b46
size 648286

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d2486fee1a77912dbb3ebba6d4c4f6418c4c1a90013f0f2a323ece5e844f2753
size 516467

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d105ac221b0ab6dd54bb4d4a33e2816514a1d3c1f5de2bbfc6f33cf8e292b093
size 634518

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3a76c94d3c24f8a511c2bdbacf33e8b25d8b933cd2992764e20227cc4aa21ee2
size 515179

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12e36aef6743d6ddd96ddb08418626fe5b3e4b83f663957386dec718d1325f8f
size 629388

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:36846461c9cee973c8e43123da81b8ce68512821cdb013bd3ff7e2b47cc4a736
size 507681

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d67d3bd621138d2201abac84f8ffb2bf1903b674b23e254c4efca4a5e5825c7a
size 661508

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e8279e11a9fef0ce15ac4cfdacfdf33796dbf85a14b5c76debdce5dc81b7e817
size 534523

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba884722cb487cfa74838af9106919249ae4d0069cd7a76593f91644558f18d8
size 654896

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9fc47e97fff712a63a4837aaa4dfe2edbd4f8b3b6d0621d6478d9862b51156cb
size 523915

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:59736e223d17e04ce5a3b11056f1cd6dcce0795705e778c4a40188129644ee4e
size 672406

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4799e5b72cb1fc46ed6b3aeb38db6a0d3c682891c70e87b91095537f7e3f6e77
size 551143

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15b965067f5d7bc01c722dd20399eb842a21e313040a47d5745bf6b305635ee6
size 667276

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2d6eb31c54ac57087ac4d1223fe88c62084b376ba9dfcbfc35032be1e179a2cd
size 543595

Some files were not shown because too many files have changed in this diff.