update trtllm-gen sm100f cubins of gemm kernels

Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
Xiwen Yu 2025-08-04 18:15:25 +08:00
parent 52ad4436bc
commit 345c2bceaa
582 changed files with 11547 additions and 4159 deletions

View File

@ -244,11 +244,48 @@ struct BatchedGemmData
// Shape is [B].
float const* mPtrScaleGate{nullptr};
// The clamp limit for the accumulator before applying the activation.
// Shape is [B].
// Clamp is INF if nullptr.
// When the input is FP8 or NVFP4, the clamp has to be scaled by limit' = limit / dequantAb.
// If applied on SwiGlu, it will be:
//
// x_glu = x_glu.clamp(min=None, max=limit)
// x_linear = x_linear.clamp(min=-limit, max=limit)
//
// The given clamp limit applies to the dequantized values, so the order of operations would
// look something like this:
//
// x0 = x0 * dqAb
// x0 = clamp(x0, none, limit)
// x0 = x0 * sigmoid(alpha * x0)
// x1 = dqAb * x1
// x1 = clamp(x1, -limit, limit)
// out = qC * (x1 + beta) * x0
//
// Given that the dqAb and qC are combined into scaleC, we can bring the dqAb into the clamp
// limit and apply the clamping prior to dequantization:
//
// x0 = clamp(x0, none, limit / dqAb)
// x0 = x0 * dqAb
// x0 = x0 * sigmoid(alpha * x0)
// x1 = clamp(x1, -limit / dqAb, limit / dqAb)
// scaleC = dqAb * qC
// beta' = beta / dqAb
// out = scaleC * (x1 + beta') * x0
//
// Note that this assumes scaleAb == scaleGate, which is true in the TRT-LLM MoE use case.
//
float const* mPtrClampLimit{nullptr};
// The alpha and beta for SwiGlu.
// gatedActivation <- (x0 + beta) * activation(x1, alpha)
// Shape is [B].
// Alpha is 1.f if nullptr.
// Beta is 0.f if nullptr.
// The formula:
//
// out_glu = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta)
float const* mPtrSwiGluAlpha{nullptr};
float const* mPtrSwiGluBeta{nullptr};
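To make the derivation above concrete, here is a minimal host-side sketch of how the per-batch scalars could be prepared before filling the device buffers. SwiGluClampParams and foldDequantIntoSwiGlu are hypothetical names used only for illustration (they are not part of this interface), and the sketch assumes scaleAb == scaleGate as noted above.

// Hypothetical helper: folds the dequantization scale dqAb into the clamp limit and beta,
// so the kernel can clamp the raw accumulator before dequantization.
struct SwiGluClampParams
{
    float clampLimit; // limit' = limit / dqAb
    float beta;       // beta'  = beta / dqAb
    float scaleC;     // scaleC = dqAb * qC
};

inline SwiGluClampParams foldDequantIntoSwiGlu(float limit, float beta, float dqAb, float qC)
{
    SwiGluClampParams p;
    p.clampLimit = limit / dqAb;
    p.beta = beta / dqAb;
    p.scaleC = dqAb * qC;
    return p;
}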
@ -591,6 +628,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
{
// Might be used.
(void) usePdl;
(void) moduleCache;
// Get options from config and data.
auto options = getOptionsFromConfigAndData(config, batchedGemmData);
@ -642,17 +680,17 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
auto const numCtaZ = options.mNumSlicesForSplitK;
mNumCtas = numCtaX * numCtaY * numCtaZ;
auto kernelParams = KernelParams::setKernelParams(options, batchM, batchedGemmData.mInputBuffers.mPtrA,
auto kernelParams = KernelParamsSetup::setKernelParams(options, batchM, batchedGemmData.mInputBuffers.mPtrA,
batchedGemmData.mInputBuffers.mPtrB, batchedGemmData.mOutputBuffers.mPtrC,
batchedGemmData.mInputBuffers.mPtrSfA, batchedGemmData.mInputBuffers.mPtrSfB,
batchedGemmData.mInputBuffers.mPtrPerTokenSfA, batchedGemmData.mInputBuffers.mPtrPerTokenSfB,
batchedGemmData.mInputBuffers.mPtrBias, batchedGemmData.mOutputBuffers.mPtrSfC,
batchedGemmData.mInputBuffers.mPtrScaleC, batchedGemmData.mInputBuffers.mPtrScaleGate,
batchedGemmData.mInputBuffers.mPtrSwiGluAlpha, batchedGemmData.mInputBuffers.mPtrSwiGluBeta,
batchedGemmData.mInputBuffers.mPtrRouteMap, dPtrRowMax, dPtrRowMaxBars,
batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas, batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens,
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx, batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit,
maxNumCtasInBatchDim);
batchedGemmData.mInputBuffers.mPtrClampLimit, batchedGemmData.mInputBuffers.mPtrSwiGluAlpha,
batchedGemmData.mInputBuffers.mPtrSwiGluBeta, batchedGemmData.mInputBuffers.mPtrRouteMap, dPtrRowMax,
dPtrRowMaxBars, batchedGemmData.mInputBuffers.mPtrNumNonExitingCtas,
batchedGemmData.mInputBuffers.mPtrTotalNumPaddedTokens, batchedGemmData.mInputBuffers.mPtrCtaIdxXyToBatchIdx,
batchedGemmData.mInputBuffers.mPtrCtaIdxXyToMnLimit, maxNumCtasInBatchDim);
// The size of the grid.
std::vector<int32_t> grid{numCtaX, numCtaY, numCtaZ};
@ -660,26 +698,26 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
#ifdef TLLM_GEN_EXPORT_INTERFACE
CUmodule cuModule;
CUfunction cuFunction;
if (moduleCache.has_value())
{
ModuleCache& moduleCacheRef = moduleCache.value().get();
// Modules are associated with a specific context so include the ctxId in the key
// Modules are associated with a specific context, so the context is included in the key
CUcontext ctx;
unsigned long long ctxId;
cuCtxGetCurrent(&ctx);
cuCtxGetId(ctx, &ctxId);
// Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a string in decimal
// representation.
// Reinterpret the ctxId as a string to avoid needing a custom hash or converting it to a
// string in decimal representation.
std::string const ctxName
= std::string(reinterpret_cast<char*>(&ctxId), sizeof(unsigned long long) / sizeof(char));
std::string const funcName = std::string(config.mFunctionName);
// As the ctxName is a fixed number of bytes, the two strings can just be appended without risk of a collision
auto const moduleKey = ctxName + funcName;
auto module = moduleCacheRef.find(moduleKey);
// Check if module exists in cache. Otherwise, load it
// Use cache if module is found, otherwise load and insert into cache
if (module != moduleCacheRef.end())
{
cuFunction = std::get<1>(module->second);
@ -716,7 +754,7 @@ int32_t BatchedGemmInterface::run(BatchedGemmConfig const& config, void* workspa
{
return -1;
}
// If a module cache has not been given, unload the module to avoid overflow
// If a module cache has not been given, unload the module to avoid leaking
if (!moduleCache.has_value())
{
cuModuleUnload(cuModule);

View File

@ -96,10 +96,10 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
int tileK, bool useUnrollLoop2xForMma, bool useCustomMmaSchedule, bool useHoistTryWaitForCustomMmaSchedule,
bool useDeepSeekFp8, bool usePerTokenSfA, bool usePerTokenSfB, bool useTmaStore, bool useTwoTmaLoadWarps,
bool useTwoMmaWarps, tg::SfLayout sfLayoutA, tg::SfLayout sfLayoutB, tg::SfLayout sfLayoutC,
int32_t sfReshapeFactor, gemm::TileScheduler tileScheduler, gemmGatedAct::ActType actType,
int32_t sfReshapeFactor, gemm::TileScheduler tileScheduler, gemmGatedAct::ActType actType, bool clampBeforeAct,
std::vector<int> batchedM, std::vector<int> batchedN, BatchMode batchMode, int numBatches, bool isStaticBatch,
int numTokens, RouteImpl routeImpl, bool gridWaitForPrimaryRouting, bool fusedAct,
int numRegsPerThreadNonEpilogueWarp, int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps)
int numRegsPerThreadNonEpilogueWarp, int numRegsPerThreadEpilogueWarp, int numRegsCastAWarps, bool useTmaOobOpt)
: gemmGatedAct::GemmGatedActOptions(
gemm::GemmOptions(allReduceAlgo, biasType, blockK, clusterDimX, clusterDimY, clusterDimZ, dtypeAcc, dtypeA,
dtypeB, dtypeC, dtypeMmaA, dtypeMmaB, enablesEarlyExit, enablesDelayedEarlyExit, enablesGlobalPtxKnobs,
@ -112,19 +112,20 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
useCustomMmaSchedule, useHoistTryWaitForCustomMmaSchedule, useDeepSeekFp8, usePerTokenSfA,
usePerTokenSfB, useTmaStore, useTwoTmaLoadWarps, useTwoMmaWarps, sfLayoutA, sfLayoutB, sfLayoutC,
sfReshapeFactor, tileScheduler),
actType)
actType, clampBeforeAct)
, mBatchedM(batchedM)
, mBatchedN(batchedN)
, mBatchMode(BatchMode(batchMode))
, mNumBatches(numBatches)
, mIsStaticBatch(isStaticBatch)
, mNumTokens(numTokens)
, mRouteImpl(routeImpl)
, mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting)
, mFusedAct(fusedAct)
, mGridWaitForPrimaryRouting(gridWaitForPrimaryRouting)
, mIsStaticBatch(isStaticBatch)
, mNumBatches(numBatches)
, mNumRegsPerThreadNonEpilogueWarp(numRegsPerThreadNonEpilogueWarp)
, mNumRegsPerThreadEpilogueWarp(numRegsPerThreadEpilogueWarp)
, mNumRegsCastAWarps(numRegsCastAWarps)
, mNumTokens(numTokens)
, mRouteImpl(routeImpl)
, mUseTmaOobOpt(useTmaOobOpt)
{
}
@ -134,28 +135,28 @@ struct BatchedGemmOptions : public gemmGatedAct::GemmGatedActOptions
std::vector<int> mBatchedN;
// Whether batching M or N.
BatchMode mBatchMode{BatchMode::BatchM};
// Number of Gemm batches.
int mNumBatches;
// Whether the batch size is static (i.e. known at kernel launch time).
bool mIsStaticBatch{true};
// Total number of tokens.
int mNumTokens{32};
// Whether to load the input tokens and do routing.
RouteImpl mRouteImpl{RouteImpl::NoRoute};
// Whether to perform a fused gated activation.
bool mFusedAct{false};
// Whether the loads that load from ptrRouteMap, ptrTotalNumPaddedTokens,
// ptrCtaIdxXyToBatchIdx, etc. should wait on a grid dependency.
bool mGridWaitForPrimaryRouting{true};
// Whether to perform a fused gated activation.
bool mFusedAct{false};
// Whether the batch size is static (i.e. known at kernel launch time).
bool mIsStaticBatch{true};
// Number of Gemm batches.
int mNumBatches;
// Number of registers per thread for non-epilogue warps
int mNumRegsPerThreadNonEpilogueWarp{0};
// Number of registers per thread for epilogue warps
int mNumRegsPerThreadEpilogueWarp{0};
// Number of registers for the cast A warps.
int mNumRegsCastAWarps{0};
// Total number of tokens.
int mNumTokens{32};
// Whether to load the input tokens and do routing.
RouteImpl mRouteImpl{RouteImpl::NoRoute};
// Whether to use TMA out-of-bounds optimization to reduce wasted traffic. See details in
// BatchedGemm/KernelParamsDecl.h.
bool mUseTmaOobOpt{false};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -165,6 +166,20 @@ bool checkAndUpdateBatchedGemmOptions(BatchedGemmOptions& options, bool isBlackw
{
bool isValid = true;
if (options.mUseTmaOobOpt && !options.mUseTwoTmaLoadWarps)
{
if (updateOptions)
{
// Any routing (mRouteImpl != RouteImpl::NoRoute) requires mUseTwoTmaLoadWarps == true, so a
// single TMA load warp is not a target use case for the OOB optimization.
options.mUseTmaOobOpt = false;
}
else
{
TLLM_CHECK_ERROR(false, "TMA OOB optimization requires two TMA load warps.");
return false;
}
}
if (options.mFusedAct)
{
// ensure that we check the fused options as well
@ -340,6 +355,7 @@ struct BatchedGemmConfig
uint32_t const mSharedMemSize{0};
char const* mFunctionName{nullptr};
uint32_t const mNumThreadsPerCTA{0};
char const* mHash{nullptr};
#else
trtllm::gen::CudaRunner* mCudaRunner{nullptr};
#endif
@ -366,7 +382,8 @@ inline std::string dumpOptions(BatchedGemmOptions const& options)
ss << "mFusedAct=" << options.mFusedAct << "," << std::endl;
ss << "mNumRegsPerThreadNonEpilogueWarp=" << options.mNumRegsPerThreadNonEpilogueWarp << "," << std::endl;
ss << "mNumRegsPerThreadEpilogueWarp=" << options.mNumRegsPerThreadEpilogueWarp << "," << std::endl;
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << std::endl;
ss << "mNumRegsCastAWarps=" << options.mNumRegsCastAWarps << "," << std::endl;
ss << "mUseTmaOobOpt=" << options.mUseTmaOobOpt << std::endl;
return ss.str();
}

View File

@ -101,14 +101,17 @@ struct GemmGatedActOptions : public gemm::GemmOptions
{
GemmGatedActOptions() = default;
GemmGatedActOptions(gemm::GemmOptions options, ActType actType)
GemmGatedActOptions(gemm::GemmOptions options, ActType actType, bool clampBeforeAct)
: gemm::GemmOptions(options)
, mActType(actType)
, mClampBeforeAct(clampBeforeAct)
{
}
// Type of the gated activation.
ActType mActType{ActType::SwiGlu};
// Clamp the dequantized values to the range [-limit, limit].
bool mClampBeforeAct{false};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -175,7 +178,8 @@ inline std::string dumpOptions(GemmGatedActOptions const& options)
std::stringstream ss;
ss << gemm::dumpOptions(options) << ", ";
ss << "mActType="
<< "gemmGatedAct::ActType(" << static_cast<int32_t>(options.mActType) << ")" << std::endl;
<< "gemmGatedAct::ActType(" << static_cast<int32_t>(options.mActType) << ")," << std::endl;
ss << "mClampBeforeAct=" << options.mClampBeforeAct << "" << std::endl;
return ss.str();
}
@ -196,6 +200,7 @@ struct GemmGatedActConfig
uint32_t const mSharedMemSize{0};
char const* mFunctionName{nullptr};
uint32_t const mNumThreadsPerCTA{0};
char const* mHash{nullptr};
#else
trtllm::gen::CudaRunner* mCudaRunner{nullptr};
#endif

View File

@ -354,6 +354,7 @@ struct GemmConfig
uint32_t const mSharedMemSize{0};
char const* mFunctionName{nullptr};
uint32_t const mNumThreadsPerCTA{0};
char const* mHash{nullptr};
#else
trtllm::gen::CudaRunner* mCudaRunner{nullptr};
#endif
@ -526,6 +527,7 @@ inline int32_t getShuffleBlockSize(int epilogueTileM)
inline bool checkAndUpdateGemmOptions(
GemmOptions& options, bool isBlackwell, int /* tpGrpSize */, bool updateOptions = true)
{
if (options.mDtypeB == tg::Dtype::Void)
{
if (updateOptions)
@ -566,7 +568,8 @@ inline bool checkAndUpdateGemmOptions(
// Currently, we only support {MxFp4, NvFp4} -> Bf16 and NvFp4 -> Fp8.
TLLM_CHECK_ERROR((options.mDtypeA == options.mDtypeMmaA)
|| ((options.mDtypeA == tg::Dtype::MxE2m1 || options.mDtypeA == tg::Dtype::E2m1)
&& options.mDtypeMmaA == tg::Dtype::Bfloat16),
&& options.mDtypeMmaA == tg::Dtype::Bfloat16)
|| (options.mDtypeA == tg::Dtype::E2m1 && options.mDtypeMmaA == tg::Dtype::E4m3),
"Unsupported cast for A: ", tg::dtypeToString(options.mDtypeA), " -> ", tg::dtypeToString(options.mDtypeMmaA));
// Check that the B cast is supported.

View File

@ -0,0 +1,547 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace batchedGemm
{
// This is device code
struct KernelParams
{
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// BatchedGemm parameters.
//
//////////////////////////////////////////////////////////////////////////////////////////////////
// Maximum number of CTAs
static constexpr int MaxNumCtas = 2048;
// NOTE: TMA out-of-bounds optimization for MoE padded tokens:
//
// Originally, the padded tokens form a 2D tensor [hiddenDim, ctaGridDimY * tileN] with stride [1,
// hiddenDim] and box size [tileM, tileN] at pointer p. We waste bandwidth since we only want to
// load [0, batchEnd) out of the [0, tileN) box size: batchEnd is a runtime variable while the box
// size must be fixed at compile time.
//
// To deal with this, we reshape the tensor to 3D: [hiddenDim, tileN, ctaGridDimY * tileN] with
// stride [1, hiddenDim, hiddenDim] and box size [tileM, tileN, 1]. For the original 2D
// tensor,
//
// Offset Coords [ : , ctaIdxY * tileN ],
// Box Sizes [ : , tileN ],
// Coords Range [ : , ctaIdxY * tileN : ctaIdxY * tileN + tileN],
//
// while we only want to load the range [ctaIdxY * tileN, ctaIdxY * tileN + batchEnd), where
// 1 <= batchEnd <= tileN.
//
// For the reshaped 3D tensor,
//
// Offset Coords [ : , tileN - batchEnd ,
// ctaIdxY * tileN + batchEnd ],
// Box Sizes [ : , tileN ,
// 1 ],
// Coords Range [ : , tileN - batchEnd : min(tileN, 2 * tileN - batchEnd),
// ctaIdxY * tileN + batchEnd : ctaIdxY * tileN + batchEnd + 1],
//
// while min(tileN, 2 * tileN - batchEnd) always evaluates to tileN. The unwanted tokens are
// essentially filtered out by utilizing the OOB feature of TMA. Since the 2nd and 3rd dimensions
// have the same stride, we end up loading the following (adding the left and right ends of the 2nd
// and 3rd dimension ranges):
//
// Effective 2D Coords Range
// [ : , tileN + ctaIdxY * tileN : tileN + ctaIdxY * tileN + batchEnd],
//
// This is exactly the same as the original range except for the offset tileN, thus we also need
// to offset the pointer in the opposite direction:
//
// Ptr (p) -> Ptr (p - tileN * hiddenDim)
//
// Due to the restrictions of the TMA unit, the above operations require the TMA descriptor and
// the underlying buffer to be constructed differently:
// - Requires a valid buffer at (p - tileN * hiddenDim), i.e. `tileN` tokens must be prepended.
// - The TMA outermost dimension must be extended by `tileN`, or loads will go out of bounds on
//   the rightmost side.
// The latter is because when batchEnd == tileN, the offset coord in the 3rd dimension becomes
// ctaIdxY * tileN + tileN. When ctaIdxY = ctaGridDimY - 1, this is (ctaGridDimY - 1) * tileN
// + tileN = ctaGridDimY * tileN, which equals the 3rd dimension size and would be filtered
// out. That's why we need to extend the tensor size by tileN.
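// Worked example (illustrative values only): with tileN = 128, hiddenDim = 4096, ctaIdxY = 2 and
// batchEnd = 48, the 3D load uses
//
//   Offset Coords [ : , 128 - 48 = 80 , 2 * 128 + 48 = 304 ],
//   Box Sizes     [ : , 128           , 1                  ],
//
// so the 2nd dimension covers [80, 128) and the 3rd covers [304, 305). Since both dimensions have
// stride hiddenDim, the effective 2D range is [80 + 304, 128 + 304) = [384, 432), i.e. the desired
// range [256, 304) shifted by tileN = 128, which the pointer offset p - tileN * hiddenDim
// compensates for.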
//
// TMA descriptor for A.
// Must be setup using gemm::buildNdTmaDescriptor with shapes and strides from
// makeTmaShapeStrideAbc.
//
// If batchM:
// Logical shape is [sum(divUpMul(M[bi], tileM) for bi in B), K].
// Logical strides are [K, 1].
// Tile box shape is [tileM, tileK].
// Tile box strides are [tileK, 1].
//
// If batchN:
// If layoutA is MatrixLayout::MajorK
// Logical shape is [B, divUpMul(M, tileM), K].
// Logical strides are [divUpMul(M, tileM) * K, K, 1].
// Tile box shape is [1, tileM, tileK].
// Tile box strides are [0, tileK, 1].
// If layoutA is MatrixLayout::Mn
// Logical shape is [B, K, divUpMul(M, tileM)].
// Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM), 1].
// Tile box shape is [1, tileK, tileM].
// Tile box strides are [0, tileM, 1].
// If layoutA is MatrixLayout::BlockMajorK
// Logical shape is [B, K / blockK, divUpMul(M, tileM), blockK].
// Logical strides are [K * divUpMul(M, tileM), divUpMul(M, tileM) * blockK, blockK, 1].
// Tile box shape is [1, tileK / min(blockK, tileK), tileM, min(blockK, tileK)].
// Tile box strides are [0, tileM * min(blockK, tileK), min(blockK, tileK), 1].
// where blockK is 128B.
//
// Dtype is set from options.mDtypeA.
CUtensorMap tmaA[1];
// TMA descriptor for B.
// Must be setup using gemm::buildNdTmaDescriptor with shapes and strides from
// makeTmaShapeStrideAbc.
//
// If batchM:
// If layoutB is MatrixLayout::MajorK
// Logical shape is [B, divUpMul(N, tileN), K].
// Logical strides are [divUpMul(N, tileN) * K, K, 1].
// Tile box shape is [1, tileN, tileK].
// Tile box strides are [0, tileK, 1].
// If layoutB is MatrixLayout::MajorMn
// Logical shape is [B, K, divUpMul(N, tileN)].
// Logical strides are [K * divUpMul(N, tileN), divUpMul(N, tileN), 1].
// Tile box shape is [1, tileK, tileN].
// Tile box strides are [0, tileN, 1].
// If layoutB is MatrixLayout::BlockMajorK
// Logical shape is [B, K / blockK, divUpMul(N, tileN), blockK].
// Logical strides are [K * divUpMul(N, tileN), divUpMul(N, tileN) * blockK, blockK, 1].
// Tile box shape is [1, tileK / min(blockK, tileK), tileN, min(blockK, tileK)].
// Tile box strides are [0, tileN * min(blockK, tileK), min(blockK, tileK), 1].
// where blockK is 128B.
//
// If batchN:
// Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B), K].
// Logical strides are [K, 1].
// Tile box shape is [tileN, tileK].
// Tile box strides are [tileK, 1].
//
// Dtype is set from options.mDtypeB.
CUtensorMap tmaB[1];
// TMA descriptor for C (when useTmaStore is true).
// Must be setup using gemm::buildNdTmaDescriptor with shapes and strides from
// makeTmaShapeStrideAbc.
//
// If batchM:
// Logical shape is [sum(divUpMul(M[bi], tileM) for bi in B), N].
// Logical strides are [N, 1].
// Tile box shape is [epilogueTileM, epilogueTileN].
// Tile box strides are [epilogueTileN, 1].
//
// If batchN:
// Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B), M].
// Logical strides are [M, 1].
// Tile box shape is [epilogueTileN, epilogueTileM].
// Tile box strides are [epilogueTileM, 1].
//
// Dtype is set from options.mDtypeC.
CUtensorMap tmaC[1];
// TMA descriptor for the block scaling factors for A, for MxFp{4,8} and NvFp4 formats.
// Must be setup using gemm::buildSfTmaDescriptor with shapes and strides from
// makeTmaShapeStrideSfAb.
// The layout of scaling factors for A is always R128c4.
//
// Let P be the number of elements per SF. P=16 for NvFp4, P=32 for Mx formats.
// M must be a multiple of 128.
// K must be a multiple of 4P.
// The "logical" shape is: [paddedM, K / P], where paddedM is
// sum(divUpMul(M[bi], tileM) for bi in B) if batchM,
// otherwise divUpMul(M, TileM) * B.
// The R128c4 layout is: [paddedM / 128, K / P / 4, 512].
// The shape we use for TMA is: [paddedM / 128, K / P / 4, 2, 256].
//
// Dtype is Dtype::E4m3 for NvFp4, Dtype::UE8m0 for Mx formats.
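// Worked example (illustrative numbers only): for NvFp4 (P = 16) with paddedM = 256 and K = 512,
// the "logical" shape is [256, 32], the R128c4 layout is [2, 8, 512], and the shape used for the
// TMA descriptor is [2, 8, 2, 256].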
CUtensorMap tmaSfA[1];
// TMA descriptor for the block scaling factors for B, for MxFp{4,8} and NvFp4 formats.
// Must be setup using gemm::buildSfTmaDescriptor with shapes and strides from
// makeTmaShapeStrideSfAb.
// The layout of block scaling factors for B is controlled by options.mSfLayoutB.
//
// Let P be the number of elements per SF. P=16 for NvFp4, P=32 for Mx formats.
// The "logical" shape is: [paddedN, K / 16]
// where paddedN is sum(divUpMul(N[bi], tileN) for bi in B) if batchN,
// otherwise divUpMul(N, TileN) * B.
//
// If the layout is R128c4,
// paddedN must be a multiple of 128.
// K must be a multiple of 4P.
// The R128c4 layout is: [paddedN / 128, K / P / 4, 512]
// The shape we use for TMA is: [paddedN / 128, K / P / 4, 2, 256]
//
// If the layout is R8c4,
// paddedN must be a multiple of 8.
// K must be a multiple of 4P.
// The R8c4 layout is: [paddedN / 8, K / P / 4, 32]
// The shape we use for TMA is: [paddedN / 8, K / P / 4 / repeats, repeats * 32]
// where repeats = min(tileK / P / 4, 8)
//
// Dtype is Dtype::E4m3 for NvFp4, Dtype::UE8m0 for Mx formats.
CUtensorMap tmaSfB[1];
// The input matrix A.
// If (routeAct == true && batchM), the shape is [M, K]. tmaA is not used.
// Otherwise, check layout of tmaA to see the shape and strides.
void const* ptrA{nullptr};
// The stride for matrix A in bytes.
// Equal to K * dtypeGetNumBits(dtypeA) / 8.
uint64_t strideInBytesA;
// The input matrix B.
// If (routeAct == true && batchN), the shape is [N, K]. tmaB is not used.
// Otherwise, check layout of tmaB to see the shape and strides.
void const* ptrB{nullptr};
// The stride for matrix B in bytes.
// Equal to K * dtypeGetNumBits(dtypeB) / 8.
uint64_t strideInBytesB;
// The output matrix C. Check "logical" layout of tmaC to see the shape and strides.
void* ptrC{nullptr};
// Inputs and output are MxFp{4,8}, Fp8, NvFp4.
// The scaling factors to apply to the output - can be used to incorporate input scaling factors
// as described below: C = SEncC * act(SDecA * SDecB * A * Bl) . (SDecA * SDecB * A * Br)
// -> ScaleGate = SDecA * SDecB
// ScaleC = SDecA * SDecB * SEncC
//
// Only the inputs are MxFp{4,8}, Fp8, NvFp4.
// C = act(SDecA * SDecB * A * Bl) . (SDecA * SDecB * A * Br)
// -> ScaleGate = SDecA * SDecB
// ScaleC = SDecA * SDecB
//
// Only the output is MxFp{4,8}, Fp8, NvFp4.
// C = SEncC * act(A * Bl) . (A * Br)
// -> ScaleGate = 1
// ScaleC = SEncC
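// Worked example (illustrative numbers only): with SDecA = 0.5, SDecB = 0.25 and SEncC = 8.0,
// the fully quantized case would use ScaleGate = 0.5 * 0.25 = 0.125 and
// ScaleC = 0.5 * 0.25 * 8.0 = 1.0; if only the inputs are quantized, ScaleGate = ScaleC = 0.125.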
//
// The output tensor scaling factor for MxFp{4,8}, Fp8, NvFp4 and DeepSeek FP8 quantization.
// TensorRT-LLM API requires a scaling factor on the device.
// Shape is [B]. One scaling factor per tensor in batch.
float const* ptrScaleC{nullptr};
// The output gate scale for MxFp{4,8}, Fp8, NvFp4 and DeepSeek FP8 quantization.
// TensorRT-LLM API requires a scaling factor on the device.
// Shape is [B]. One scaling factor per tensor in batch.
float const* ptrScaleGate{nullptr};
// The clamp limit before the activation.
// Shape is [B].
// Clamp is INF if nullptr.
// If applied on SwiGlu, it will be:
//
// x_glu = x_glu.clamp(min=None, max=limit)
// x_linear = x_linear.clamp(min=-limit, max=limit)
float const* ptrClampLimit{nullptr};
// The alpha and beta for SwiGlu.
// Shape is [B]. One alpha and one beta per tensor in batch.
// Alpha is 1.f if nullptr.
// Beta is 0.f if nullptr.
// The formula:
//
// out_glu = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta)
float const* ptrSwiGluAlpha{nullptr};
float const* ptrSwiGluBeta{nullptr};
// The K dimension. It is the hidden dimension of the input matrices.
int32_t k;
// The non-batched dimension.
// It is N if batchM, otherwise M.
int32_t nm;
// Tile stride per batch for the non-batched dimension.
// It is N / TileN if batchM, otherwise M / TileM.
int32_t tileStridePerBatch;
// TODO get rid of that.
// DeepSeek FP8 scaling factors for C
float* ptrDqSfsC{nullptr};
// The block scaling factors for A.
// The pointer must always be set regardless of the quantization recipe.
// If (routeAct == true && batchM), the shape is [M, K / 16]. tmaSfA is not used.
// For the layout (r128c4), see below.
// Otherwise,
// If MxFp{4,8} and NvFp4 formats are used,
// check the "logical" layout of tmaSfA to see the shape and strides.
// The dtype is Dtype::E4m3.
//
// If DeepSeek FP8 quantization recipe is used,
// If batchM:
// The shape is [K / 128, paddedM],
// where paddedM is sum(divUpMul(M[bi], tileM) for bi in B).
// If batchN:
// The shape is [M / 128, K / 128],
// The rightmost dimension is contiguous in memory.
// The dtype is Dtype::Float32.
void const* ptrSfA{nullptr};
// The block scaling factors for B.
// The pointer must always be set regardless of the quantization recipe.
// If (routeAct == true && batchN), the shape is [N, K / 16]. tmaSfB is not used.
// For the layout (r128c4, r8c4), see below.
// Otherwise,
// If MxFp{4,8} and NvFp4 formats are used,
// check the layout of tmaSfB to see the shape and strides.
// The dtype is Dtype::E4m3.
//
// If DeepSeek FP8 quantization recipe is used,
// If batchM:
// The shape is [N / 128, K / 128],
// If batchN:
// The shape is [K / 128, paddedN],
// where paddedN is sum(divUpMul(N[bi], tileN) for bi in B).
// The rightmost dimension is contiguous in memory.
// The dtype is Dtype::Float32.
void const* ptrSfB{nullptr};
// The per-token scaling factors from scale A.
//
// This is used for either:
// * Per-token scaling factor quantization schemes, such as MetaFP8. The dtype is Dtype::Float32
// * When the routing scales are applied to the input activations (only when output is not
// transposed). The dtype is Dtype::Bfloat16
//
// if (batchM (A is activations)):
// Logical shape is [sum(divUpMul(M[bi], tileM) for bi in B)]
//
// if (batchN (A is weights)):
// Logical shape is [B, divUpMul(M, tileM)]
//
void const* ptrPerTokenSfA{nullptr};
// The per-token scaling factors from scale B.
//
// This is used for either:
// * Per-token scaling factor quantization schemes, such as MetaFP8. The dtype is Dtype::Float32
// * When the routing scales are applied to the input activations (only when output is
// transposed). The dtype is Dtype::Bfloat16
//
// if (batchM (B is weights)):
// Logical shape is [B, divUpMul(N, tileN)]
//
// if (batchN (B is activations)):
// Logical shape is [sum(divUpMul(N[bi], tileN) for bi in B)]
void const* ptrPerTokenSfB{nullptr};
// The bias applied after the GEMM and before the activation function.
// The bias is applied before applying the global scaling factor. I.e.
// C = act(A * B + bias') * scaleC
// scaleC = dequantA * dequantB * quantC
// Thus, bias' = bias / (dequantA * dequantB), where bias is the original bias.
//
// If batchM, BiasType must be N, and bias shape is [B, N].
// The bias is broadcasted along the M dimension.
//
// If batchN, BiasType must be M, and bias shape is [B, M].
// The bias is broadcasted along the N dimension.
//
// The dtype is float32.
void const* ptrBias{nullptr};
// The output block scaling factors for C.
//
// If MxFp{4,8} and NvFp4 formats are used,
// The "logical" shape is:
// if batchM: [paddedM, N / 16]
// if batchN: [paddedN, M / 16]
// where paddedM is sum(divUpMul(M[bi], tileM) for bi in B),
// where paddedN is sum(divUpMul(N[bi], tileN) for bi in B).
//
// If the layout is R128c4,
// paddedOuter must be a multiple of 128.
// inner must be a multiple of 64.
// The R128c4 layout is: [paddedOuter / 128, inner / 16 / 4, 512]
// The shape we use for TMA is: [paddedOuter / 128, inner / 16 / 4, 2, 256]
// where inner = N if batchM, otherwise M.
// where paddedOuter = paddedM if batchM, otherwise paddedN.
//
// If the layout is R8c4,
// paddedOuter must be a multiple of 8.
// inner must be a multiple of 64.
// The R8c4 layout is: [paddedOuter / 8, inner / 16 / 4, 32]
// The shape we use for TMA is: [paddedOuter / 8, inner / 16 / 4 / repeats, repeats * 32]
// where repeats = min(tileInner / 16 / 4, 8),
// where tileInner = tileN if batchM, otherwise tileM,
// where paddedOuter = paddedM if batchM, otherwise paddedN.
// where inner = N if batchM, otherwise M.
//
// The dtype is Dtype::E4m3.
//
// If DeepSeek FP8 quantization recipe is used,
// If batchM:
// The shape is [N / 128, paddedM],
// where paddedM is sum(divUpMul(M[bi], tileM) for bi in B).
// If batchN:
// The shape is [M / 128, paddedN],
// where paddedN is sum(divUpMul(N[bi], tileN) for bi in B).
// The rightmost dimension is contiguous in memory.
// The dtype is Dtype::Float32.
void* ptrSfC{nullptr};
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// Routing activations parameters.
//
//////////////////////////////////////////////////////////////////////////////////////////////////
// These params are used when the kernel is configured with -routeAct true.
// The inputs are not padded, but the outputs are padded to divUpMul(M[bi], tileM) for batchM or
// divUpMul(N[bi], tileN) for batchN.
// If -routeAct is false, the params are not used and should be set to zero.
// The routeMap for the input tokens.
// Map of expanded token index (counting the previous padded tokens) to the batch index
// the token belongs to.
// The shape is
// [sum(divUpMul(M[bi], tileM) for bi in B)] for batchM
// [sum(divUpMul(N[bi], tileN) for bi in B)] for batchN
// The dtype is int32_t.
//
// E.g., there are 3 tokens [0, 1, 2] such that [0, 1] belong to batch [B0] and [2] to batch [B1].
// Let's assume that the padded size is 4.
//
// The expanded indices for tokens [0, 1, 2] are:
// expandedIdx[0] = 0
// expandedIdx[1] = 1
// expandedIdx[2] = divUpMul(2, 4) + 0 = 4
//
// The route map is [B0, B0, X, X, B1, X, X, X] where X could be any value.
int32_t const* ptrRouteMap{nullptr};
// Total number of unpadded inputs
int32_t numTokens;
// Total number of batches
int32_t numBatches;
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// Batching information parameters.
//
//////////////////////////////////////////////////////////////////////////////////////////////////
// In some cases, some CTAs must early-exit. E.g. when the grid size is set statically, but the
// actual workload is decided at runtime. This element on the device contains the number of CTAs
// that do not early-exit. The number corresponds to the X dim of the grid when the output is not
// transposed (i.e. batchM), and to the Y dim otherwise.
// The size is 1 and the dtype is int32_t.
// Used if isStaticBatch == false, otherwise set to nullptr.
// The pointer points to a scalar and the dtype is int32_t. The pointed value must be >= 0.
int32_t const* ptrNumNonExitingCtas{nullptr};
// Pointer to total number of padded tokens.
// Computed as
// int32_t totalNumPaddedTokens{0};
// for (int bi = 0; bi < options.mNumBatches; bi++) {
// totalNumPaddedTokens += batchM ? divUpMul(options.mBatchedM[bi], options.mTileM)
// : divUpMul(options.mBatchedN[bi], options.mTileN);
// }
// The size is 1 and the dtype is int32_t.
// If isStaticBatch == true, ptrTotalNumPaddedTokens should be set to nullptr and
// totalNumPaddedTokens is used.
int32_t const* ptrTotalNumPaddedTokens{nullptr};
// Pointer to the map from the CTA index (in X/Y dim) to the batch index.
// Maps CTA index in batch dim (i.e. blockDim.x if batchM, otherwise blockDim.y)
// to batch index.
// E.g. with listM = 128,255,32 and tileM = 128, should be equal to
// ctaIdxXyToBatchIdx = [0, 1, 1, 2]
// If isStaticBatch == true, ptrCtaIdxXyToBatchIdx should be set to nullptr and ctaIdxXyToBatchIdx
// is used.
int32_t const* ptrCtaIdxXyToBatchIdx{nullptr};
// Pointer from the CTA index X/Y to the expanded tile index where the expanded tile index is
// computed as:
//
//   int expandedIdx = 0;
//   for (int bi = 0; bi < batchIdx; ++bi) {
//     expandedIdx += divUpMul(numTokens[bi], TileM/N);
//   }
//   expandedIdx += <index in the batch>
// E.g. with numTokens = [128,255,32] and tileM = 128, should be equal to
// ptrCtaIdxXyToMnLimit = [128, 256, 383, 416]
int32_t const* ptrCtaIdxXyToMnLimit{nullptr};
// Total number of padded tokens - used as the stride for the activation and C scaling factors.
// Check ptrTotalNumPaddedTokens to see how it is computed.
// If isStaticBatch == true, totalNumPaddedTokens is used, otherwise ptrTotalNumPaddedTokens.
int32_t totalNumPaddedTokens;
// A map from CTA index X/Y to batch index.
// Check ptrCtaIdxXyToBatchIdx to see how it is computed.
// If isStaticBatch == true, ctaIdxXyToBatchIdx is used, otherwise ptrCtaIdxXyToBatchIdx.
int32_t ctaIdxXyToBatchIdx[MaxNumCtas];
// **Expanded** limits for the batched dimension:
// tile * ctaIdxXyToTileIdxMn[ctaIdxXy] -> ctaIdxXyToMnLimit[ctaIdxXy]
// Check ptrCtaIdxXyToMnLimit to see how it is computed.
// If isStaticBatch == true, ctaIdxXyToMnLimit is used, otherwise ptrCtaIdxXyToMnLimit.
int32_t ctaIdxXyToMnLimit[MaxNumCtas];
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// All-reduce parameters.
//
//////////////////////////////////////////////////////////////////////////////////////////////////
// The rank id of the current device in the multi-gpu space.
int rank;
// The number of peer devices in tensor-parallel group.
int tpGrpSize;
//////////////////////////////////////////////////////////////////////////////////////////////////
//
// GatedAct parameters.
//
//////////////////////////////////////////////////////////////////////////////////////////////////
// Pointer for partial row max for DeepSeek FP8 recipe.
// This is temporary storage for the row max results.
// If batchM, the shape is [2, totalNumPaddedTokens, N / 128] and the dtype is float.
// Otherwise, the shape is [2, totalNumPaddedTokens, M / 128] and the dtype is float.
float* ptrPartialRowMax{nullptr};
// Flags in global memory that sync on "exit" for row max computation.
// The shape is [numTilesM * numTilesN / 2] and the dtype is uint32_t, where
// if batchM,
// numTilesM = divUp(totalNumPaddedTokens, tileM).
// numTilesN = divUp(N, tileN).
// Otherwise,
// numTilesM = divUp(M, tileM).
// numTilesN = divUp(totalNumPaddedTokens, tileN).
//
// The memory must be set to 0 before the kernel launch.
uint32_t* ptrRowMaxCompletionBars{nullptr};
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace batchedGemm

View File

@ -20,6 +20,7 @@
#include "trtllm/gen/CommonUtils.h"
#include "trtllm/gen/DtypeDecl.h"
#include <cassert>
#include <stdexcept>
namespace batchedGemm
{
@ -77,6 +78,38 @@ public:
}
// Returns the offset of the chunk with the given name.
int32_t getChunkOffsetByName(std::string const& name) const
{
for (size_t ii = 0; ii < mSmemChunkNames.size(); ++ii)
{
if (mSmemChunkNames[ii] == name)
{
return getChunkOffset(ii);
}
}
throw std::runtime_error("Name not found: " + name);
}
// Returns the first chunk reuse flag for the chunk with the given name.
int getFirstChunkReuseFlagByName(std::string const& name) const
{
for (size_t ii = 0; ii < mSmemChunkNames.size(); ++ii)
{
if (mSmemChunkNames[ii] == name)
{
return getFirstChunkReuseFlag(ii);
}
}
throw std::runtime_error("Name not found: " + name);
}
// Function to calculate the total size of the SMEM array
int32_t getTotalSize() const
{
return getOffsetBeforeChunk(static_cast<int32_t>(mNumBytesAndAlignmentPerSmemChunk.size()));
}
private:
int32_t getChunkOffset(int32_t ii) const
{
if (mFirstChunkReuse[ii])
@ -91,12 +124,6 @@ public:
return getSizePaddedToAlignment(offset, mNumBytesAndAlignmentPerSmemChunk[ii].second);
}
// Function to calculate the total size of the SMEM array
int32_t getTotalSize() const
{
return getOffsetBeforeChunk(static_cast<int32_t>(mNumBytesAndAlignmentPerSmemChunk.size()));
}
// Returns the first chunk reuse flag for the ith chunk.
int getFirstChunkReuseFlag(int32_t ii) const
{
@ -139,9 +166,7 @@ int getNumSmemBitsPerElt(tg::Dtype dtype, tg::MmaKind mmaKind)
{
if (mmaKind == tg::MmaKind::Auto)
{
std::cout << "mmaKind != tg::MmaKind::Auto" << std::endl;
assert(false);
return -1;
throw std::runtime_error("mmaKind != tg::MmaKind::Auto");
}
if (mmaKind == tg::MmaKind::MxFp8Fp6Fp4)
{
@ -541,14 +566,14 @@ inline int32_t getTmemBufferSize(KernelTraits traits)
inline int32_t getSmemOffsetLoadA(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(0);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemLoadA");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetLoadB(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(1);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemLoadB");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -562,64 +587,63 @@ inline int32_t getSmemOffsetLoadAb(KernelTraits traits)
inline int32_t getSmemOffsetLoadShuffleB(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(2);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemBShuffle");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetGmemC(KernelTraits traits, int resIdx = 0)
{
return traits.mSmemAllocatorHelper.getChunkOffset(3 + resIdx);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemGmemC" + std::to_string(resIdx));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetRowMax(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(5);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemRowMax");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetSliceK(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(6);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemSliceK");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetPerTokenSf(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(7);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemPerTokenSf");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetBias(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(8);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemBias");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetBlockAmax(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(9);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemBlockAmax");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getSmemOffsetConstSfBuf(KernelTraits traits)
{
return traits.mSmemAllocatorHelper.getChunkOffset(10);
return traits.mSmemAllocatorHelper.getChunkOffsetByName("smemConstSfBuf");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t isSmemAbRepurposedToGmemC(KernelTraits traits, int resIdx = 0)
{
// Be conscious that the index (3 + resIdx) should match the index in getSmemOffsetGmemC().
return traits.mSmemAllocatorHelper.getFirstChunkReuseFlag(3 + resIdx);
return traits.mSmemAllocatorHelper.getFirstChunkReuseFlagByName("smemGmemC" + std::to_string(resIdx));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -630,28 +654,28 @@ inline int32_t isSmemAbRepurposedToGmemC(KernelTraits traits, int resIdx = 0)
inline int32_t getTmemOffsetD(KernelTraits traits)
{
return traits.mTmemAllocatorHelper.getChunkOffset(0);
return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemD");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getTmemOffsetA(KernelTraits traits)
{
return traits.mTmemAllocatorHelper.getChunkOffset(1);
return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemA");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getTmemOffsetSfA(KernelTraits traits)
{
return traits.mTmemAllocatorHelper.getChunkOffset(2);
return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemSfA");
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline int32_t getTmemOffsetSfB(KernelTraits traits)
{
return traits.mTmemAllocatorHelper.getChunkOffset(3);
return traits.mTmemAllocatorHelper.getChunkOffsetByName("tmemSfB");
}
////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -181,6 +181,8 @@ inline CUtensorMap buildNdTmaDescriptor(tg::Dtype dtype, tg::MmaKind mmaKind, st
if (result != CUDA_SUCCESS)
{
char const* errorString;
cuGetErrorString(result, &errorString);
std::stringstream ss;
ss << "Error: Failed to initialize the TMA descriptor " << result << std::endl;
@ -283,8 +285,10 @@ inline CUtensorMap buildSfTmaDescriptor(tg::Dtype dtype, std::vector<uint64_t> c
if (result != CUDA_SUCCESS)
{
char const* errorString;
cuGetErrorString(result, &errorString);
std::stringstream ss;
ss << "Error: Failed to initialize the TMA descriptor for SF " << result << std::endl;
ss << "Error: Failed to initialize the TMA descriptor for SF " << errorString << std::endl;
ss << "tmaFormat: " << static_cast<int>(tmaDataFormat) << " dim: " << dim << " gmem: " << gmemAddr << std::endl;

View File

@ -12,7 +12,6 @@
"epilogueTileM": 128,
"epilogueTileN": 8,
"numStages": 4,
"numMmaStages": 1,
"numSlicesForSplitK": 1,
"useTwoTmaLoadWarps": true,
"clusterDimX": 1,
@ -30,7 +29,6 @@
"sfLayoutB": "8x4",
"sfLayoutC": "8x4",
"batch": "N",
"useMetaFp8": false,
"numExperts": 128,
"useCudaGraph": true
},
@ -46,7 +44,6 @@
"epilogueTileM": 128,
"epilogueTileN": 8,
"numStages": 3,
"numMmaStages": 1,
"numSlicesForSplitK": 1,
"useTwoTmaLoadWarps": true,
"clusterDimX": 1,
@ -62,7 +59,6 @@
"gridWaitForPrimaryA": false,
"gridWaitForPrimaryB": true,
"batch": "N",
"useMetaFp8": false,
"numExperts": 128,
"useCudaGraph": true
},
@ -97,7 +93,6 @@
"hoistMmaTaskTryWaits": true,
"numStagesMma": 4,
"batch": "N",
"useMetaFp8": false,
"numExperts": 128,
"useCudaGraph": true
}
@ -107,7 +102,6 @@
"_template": "BatchedGemmFp4LowLatency",
"routeAct": false,
"fusedAct": false,
"useRoutingScalesOnInput": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": ["bf16", "fp16", "e2m1"],
"listN": "8,8",
@ -119,7 +113,6 @@
"_template": "BatchedGemmPerTensorScalingFp8LowLatency",
"routeAct": false,
"fusedAct": false,
"useRoutingScalesOnInput": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": ["bf16", "fp16", "e4m3"],
"listN": "8,8",
@ -131,7 +124,6 @@
"_template": "BatchedGemmDeepSeekFp8LowLatency",
"routeAct": false,
"fusedAct": false,
"useRoutingScalesOnInput": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": ["bf16", "fp16", "e4m3"],
"listN": "8,8",
@ -145,7 +137,6 @@
"routeAct": true,
"fusedAct": true,
"sfLayoutB": "linear",
"useRoutingScalesOnInput": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "e2m1",
"numTokens": 2,
@ -166,7 +157,6 @@
"_template": "BatchedGemmFp4LowLatency",
"routeAct": false,
"fusedAct": false,
"useRoutingScalesOnInput": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "bf16",
"numTokens": 2,
@ -191,7 +181,6 @@
"_template": "BatchedGemmDeepSeekFp8LowLatency",
"routeAct": true,
"fusedAct": false,
"useRoutingScalesOnInput": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "e4m3",
"numTokens": 2,
@ -212,7 +201,6 @@
"_template": "BatchedGemmDeepSeekFp8LowLatency",
"routeAct": false,
"fusedAct": false,
"useRoutingScalesOnInput": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "bf16",
"numTokens": 2,
@ -233,7 +221,6 @@
"_template": "BatchedGemmPerTensorScalingFp8LowLatency",
"routeAct": true,
"fusedAct": true,
"useRoutingScalesOnInput": true,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "e4m3",
"numTokens": 2,
@ -247,7 +234,6 @@
"_template": "BatchedGemmPerTensorScalingFp8LowLatency",
"routeAct": false,
"fusedAct": false,
"useRoutingScalesOnInput": false,
"useUnrollLoop2xForMma": [true, false],
"dtypeC": "bf16",
"numTokens": 2,

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:afb26c01c09a14d5fd5243d018164e3cb47f6acb26b380810bd24ea20b4deef3
size 549397

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fd8cf63739a7968321556b7980ce56577f9692921a7f62043295be7e0f9f19eb
size 575499

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9d42abe6ddcac40aaa119421f50038adb4edf9f4e12931ba8528a858bd4de0ca
size 514477

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7882097269f622d56d99f704b42a0331d1fdc59442e803ff9c8f836b67c7d7d2
size 417333

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:15e4735cc2d2e2bf3fdb10698e20303a532939acdfcb6c8c691b2a2d569663a6
size 440671

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f4e0e4a92d078fc240e9006ae271cc08a91f214d31cc40269affe86f83b664ec
size 555211

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b6d33c4ccff36bae9207ff93168d3607d5b0681bc586d50b8c6995b409661403
size 581313

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0463544c6a80fd23dbdc767dfc96f512b2757dec2d976c917a9850a16f9e6f5
size 527297

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2cb24d3bc83aabe4c4ae6bac627a70fbc5cfc537f917a6c93b8d6830e1e0c373
size 545849

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:08067655e5ad721255bb86ec17617c7217ef9dc1165d1ce4bcc11c5a86dec681
size 415845

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:342721dcf5114816a32a15d31e24711a2aadd52ea4eb578695c502324786f163
size 439973

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f11f3bc7b0c137577c335e4762068a0de4785a544d85f2575b787360c300f0a8
size 548601

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2a35b8ab05817e3168b36002a8c0bfa45ff6cbe09d3a2db4af8be4dfb95fb6e6
size 574701

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c4f2e7ba6b10a99f386253c6b9315620804590f2feae7dd78351c3fce34d9ec
size 513681

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:86179f2a7d2ddbffcb8b40a49f5babc31dc6dd80acff882a67a1eae40d396b24
size 532233

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:59566acf5116dd2221842073c5fcea6bcf70eb5ee29b14482e5d4efd33ebadb4
size 415745

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:41933fab6978bc725dad91ccd0539a25f804f1e579a0360d3d6289eab0e076de
size 439873

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ea03bd9dbfb524b5f20de9aa11629b4c56aa23d03162bce221db2823bb227a44
size 695858

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42aeda3c7a2ba5b2506cfd9d38064a2ed4ea72f93758b1f3f11c375461b36c88
size 575929

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:815661113030c8580cc963c443afc3c82fff7c8f8dd8a0ed98f95a08a91f619a
size 684616

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:75593264df3d37f7d23664c02e426196a5f7ee1cc6de76db84895fca4e706c97
size 562811

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cc2a320b8589746d4193fab08c40f2221bc31c4d6f0e65db57386d6a38c75a19
size 718062

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ceb630ea13703b70566398f5421b6fa5c398b72cf415cfcb34479ad59858adf4
size 598083

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11afdd6753faeffe1446abed54c61ccf0e190ff886fc0270d055a4dbce9a9298
size 705390

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e4ecf7abe90d1a7f46a1ff3005cad8ebe0de8e7cf89862cf993f228a98ea00d
size 582747

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1245c751dbef978b1c4befcf5bd43a8f2c9579d82bda49310b8ff3a5f9b51555
size 677458

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce89e844f370baa26823f610f9f4bf92febd85edcf73cfa48aa79a24b2acad55
size 544207

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7cccf1e15a06823668068d7d8c762b9e9152f9b16b00e9f9096ea525c8fd2fc2
size 541229

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e921f66e6adc5292b99df5cfbae4a9cbae6182c8e99bbc83ea30bd1ca8ed8f55
size 667892

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3b6994776b67405b359d471fa5c873efa4714dd71397ddfd82a52d59cbf20a9a
size 550035

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:04e231d8aefe68dce4ddf7ad71a02d4ab36ee1b22896526a2e3415be9bd2704c
size 710466

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec9c13b6822a66c5c2f4be788e25b82e0373728ce65beacc599967a38a3d1dc1
size 570309

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ebad2f94ecd782e421bb1086268d9d6906e86f119995bbcc8a4da01233c6e65a
size 567379

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ec5cf07ec8b7a4405c305fb82f9eb7179a4a43ab14a2eacfadc35072b317cfd7
size 700704

Some files were not shown because too many files have changed in this diff.