[TRTLLM-9372][feat] Enable CuteDSL MoE with Large EP (#9592)

Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
Enwei Zhu 2025-12-06 14:08:52 +08:00 committed by GitHub
parent c2f2add6df
commit 7cd5a67e25
19 changed files with 757 additions and 379 deletions

View File

@ -38,6 +38,7 @@
#include <optional>
#include <sstream>
#include <string>
#include <unordered_map>
#ifndef _WIN32 // Linux
#include <sys/sysinfo.h>
#endif // not WIN32
@ -432,6 +433,21 @@ inline int getMaxSharedMemoryPerBlockOptin()
return nByteMaxSharedMemoryPerBlockOptin;
}
template <typename T>
inline int getMaxActiveBlocksPerSM(T kernel, int blockSize, size_t dynamicSMemSize)
{
static std::unordered_map<T, int> cache;
auto it = cache.find(kernel);
if (it != cache.end())
{
return it->second;
}
int numBlocks;
check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSMemSize));
cache[kernel] = numBlocks;
return numBlocks;
}
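// Illustrative usage sketch (hypothetical example, not part of this change): `exampleZeroKernel` and
// `launchExampleZeroKernel` are made-up names showing how the cached occupancy query above is meant to size a
// grid-stride kernel launch, replacing a fixed blocks-per-SM heuristic.
template <int kThreadsPerBlock>
__global__ void exampleZeroKernel(float* data, int n)
{
    // Grid-stride loop so a capped grid still covers all elements.
    for (int i = blockIdx.x * kThreadsPerBlock + threadIdx.x; i < n; i += gridDim.x * kThreadsPerBlock)
    {
        data[i] = 0.f;
    }
}
inline void launchExampleZeroKernel(float* data, int n, cudaStream_t stream)
{
    constexpr int kThreadsPerBlock = 256;
    auto kernel = &exampleZeroKernel<kThreadsPerBlock>;
    static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int const maxBlocksPerSM
        = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, /*dynamicSMemSize=*/0);
    // Launch up to the occupancy-derived block count, capped by the number of blocks actually needed.
    int const blocks = std::max(1, std::min(smCount * maxBlocksPerSM, (n + kThreadsPerBlock - 1) / kThreadsPerBlock));
    kernel<<<blocks, kThreadsPerBlock, 0, stream>>>(data, n);
}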
template <typename T1, typename T2>
inline size_t divUp(T1 const& a, T2 const& b)
{

View File

@ -67,7 +67,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
@ -110,7 +110,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
@ -141,11 +141,11 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
}
#endif
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;
auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;
cudaLaunchConfig_t config;
config.gridDim = blocks;
@ -195,7 +195,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
int32_t const token_idx = blockIdx.x;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
@ -232,7 +232,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
@ -277,6 +277,105 @@ INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
#endif
#undef INSTANTIATE_MOE_UNPERMUTE
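// Zero-fills the output rows that the fused-finalize GEMM will later accumulate into. Each permuted token inside
// the valid tiles is mapped back to its source token, and only the token's first valid top-k slot performs the
// memset, so every output row with at least one locally computed expert is cleared exactly once; rows with no
// local expert contribution are left untouched.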
template <typename InputType, int32_t kThreadsPerBlock>
__global__ void moeOutputMemsetKernel(InputType* input, int32_t const* tile_idx_to_mn_limit,
int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,
int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
{
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
InputType rmem[kElemPerCopy];
#pragma unroll
for (int32_t j = 0; j < kElemPerCopy; j++)
{
rmem[j] = 0;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
{
int32_t const tile_idx = permuted_idx / tile_size;
if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
{
continue;
}
int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
int32_t const token_idx = expanded_idx / top_k;
int32_t const topk_idx = expanded_idx % top_k;
bool is_first_in_topk = true;
for (int32_t k = 0; k < topk_idx; k++)
{
if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0)
{
is_first_in_topk = false;
break;
}
}
if (!is_first_in_topk)
{
continue;
}
auto* dst_ptr = reinterpret_cast<ElemCopyType*>(input) + token_idx * kCopyPerToken;
for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
{
dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename InputType>
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
cudaStream_t stream)
{
int32_t constexpr kThreadsPerBlock = 256;
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
auto kernel = &moeOutputMemsetKernel<InputType, kThreadsPerBlock>;
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;
cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel, input, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
}
#define INSTANTIATE_MOE_OUTPUT_MEMSET(InputType) \
template void moeOutputMemset<InputType>(InputType * input, int32_t const* tile_idx_to_mn_limit, \
int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx, \
int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size, \
int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
INSTANTIATE_MOE_OUTPUT_MEMSET(half);
#ifdef ENABLE_BF16
INSTANTIATE_MOE_OUTPUT_MEMSET(__nv_bfloat16);
#endif
#undef INSTANTIATE_MOE_OUTPUT_MEMSET
template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
int32_t kThreadsPerBlock>
__global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
@ -297,7 +396,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
ActFn act{};
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
cudaGridDependencySynchronize();
#endif
float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
@ -353,7 +452,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
@ -382,10 +481,6 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
}
#endif
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;
auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
float const* global_sf, SFType* output_sf,
int32_t const* tile_idx_to_mn_limit,
@ -424,6 +519,11 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
};
auto kernel = get_act_kernel(activation_params.activation_type);
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
int32_t const threads = kThreadsPerBlock;
cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;

View File

@ -32,6 +32,12 @@ void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t co
TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
cudaStream_t stream);
template <typename InputType>
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
cudaStream_t stream);
template <typename InputType, typename OutputType, typename SFType>
void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,

View File

@ -1587,11 +1587,6 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
int64_t num_padding_tokens = 0;
#endif
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
int64_t const threads = EXPAND_THREADS_PER_BLOCK;
auto func = [&]()
{
#ifdef ENABLE_FP8
@ -1637,6 +1632,12 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
}
}();
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, EXPAND_THREADS_PER_BLOCK, 0);
int32_t const blocks
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(num_rows * k, num_padding_tokens)));
int32_t const threads = EXPAND_THREADS_PER_BLOCK;
cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;
@ -1891,15 +1892,18 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
if (parallelism_config.ep_size > 1 && enable_alltoall)
{
// If all-to-all comm is enabled, finalizeMoeRouting doesn't need to fill the invalid output tokens with zeros.
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
int64_t const blocks = smCount * 8;
int64_t const threads = FINALIZE_THREADS_PER_BLOCK;
config.gridDim = blocks;
config.blockDim = threads;
auto func = final_scales
? &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT>
: &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::NO_SCALE>;
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM
= tensorrt_llm::common::getMaxActiveBlocksPerSM(func, FINALIZE_THREADS_PER_BLOCK, 0);
int32_t const blocks = std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(num_rows * experts_per_token));
int32_t const threads = FINALIZE_THREADS_PER_BLOCK;
config.gridDim = blocks;
config.blockDim = threads;
cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales,
unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts,
expert_first_token_offset, num_rows, padded_cols, unpadded_cols, experts_per_token, num_experts_per_node,
@ -2235,11 +2239,6 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
int64_t num_padding_tokens = 0;
#endif
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens));
int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
auto fn = [&]()
{
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
@ -2302,6 +2301,12 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
}
}();
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0);
int32_t const blocks
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(expanded_num_tokens, num_padding_tokens)));
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;
cudaLaunchConfig_t config;
config.gridDim = blocks;
config.blockDim = threads;

View File

@ -647,7 +647,9 @@ void run(Data& data, void* stream)
//
// The upper bound is a strict requirement. The number of blocks should be determined by querying
// the device properties, or conservatively low.
static int const numBlocksCoop = tensorrt_llm::common::getMultiProcessorCount();
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
// WAR: Reserve 8 SMs for overlapping kernels.
int const numBlocksCoop = smCount - 8;
// Maximum number of tokens supported by the kernel using a cooperative launch.
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;

View File

@ -139,6 +139,8 @@ std::tuple<torch::Tensor, torch::optional<torch::Tensor>> moe_permute(torch::Ten
TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32.");
int64_t const num_tiles = tile_idx_to_mn_limit.size(0);
TORCH_CHECK(permuted_idx_to_expanded_idx.dim() == 1, "permuted_idx_to_expanded_idx must be 1D.");
TORCH_CHECK(
permuted_idx_to_expanded_idx.scalar_type() == torch::kInt32, "permuted_idx_to_expanded_idx must be int32.");
int64_t const max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0);
TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles,
"max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles.");
@ -253,6 +255,69 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c
return output;
}
void moe_output_memset_inplace(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit,
torch::Tensor const& expanded_idx_to_permuted_idx, torch::Tensor const& permuted_idx_to_expanded_idx,
torch::Tensor const& num_non_exiting_tiles, int64_t const tile_tokens_dim, int64_t const top_k,
int64_t const ep_size, bool const enable_alltoall = false)
{
TORCH_CHECK(input.dim() == 2, "input must be 2D.");
int64_t const num_tokens = input.size(0);
int64_t const hidden_size = input.size(1);
TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
TORCH_CHECK(
expanded_idx_to_permuted_idx.scalar_type() == torch::kInt32, "expanded_idx_to_permuted_idx must be int32.");
TORCH_CHECK(
expanded_idx_to_permuted_idx.size(0) == num_tokens, "expanded_idx_to_permuted_idx.size(0) must be num_tokens.");
TORCH_CHECK(expanded_idx_to_permuted_idx.size(1) == top_k, "expanded_idx_to_permuted_idx.size(1) must be top_k.");
TORCH_CHECK(tile_idx_to_mn_limit.dim() == 1, "tile_idx_to_mn_limit must be 1D.");
TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32.");
int64_t const num_tiles = tile_idx_to_mn_limit.size(0);
TORCH_CHECK(permuted_idx_to_expanded_idx.dim() == 1, "permuted_idx_to_expanded_idx must be 1D.");
TORCH_CHECK(
permuted_idx_to_expanded_idx.scalar_type() == torch::kInt32, "permuted_idx_to_expanded_idx must be int32.");
int64_t const max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0);
TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles,
"max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles.");
TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k,
"max_num_permuted_tokens must be greater than or equal to num_tokens * top_k.");
TORCH_CHECK(num_non_exiting_tiles.numel() == 1, "num_non_exiting_tiles must have 1 element.");
TORCH_CHECK(num_non_exiting_tiles.scalar_type() == torch::kInt32, "num_non_exiting_tiles must be int32.");
auto const& stream = at::cuda::getCurrentCUDAStream(input.get_device());
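// Dispatch: when all-to-all is disabled or ep_size <= top_k, a plain full-tensor cudaMemsetAsync is used;
// otherwise the dedicated kernel zeroes only the output rows that will receive a local expert contribution.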
#define DISPATCH_MOE_OUTPUT_MEMSET(InputType) \
do \
{ \
if (!enable_alltoall || ep_size <= top_k) \
{ \
cudaMemsetAsync(input.data_ptr(), 0x0, sizeof(InputType) * num_tokens * hidden_size, stream); \
} \
else \
{ \
tensorrt_llm::kernels::cute_dsl::moeOutputMemset<InputType>(static_cast<InputType*>(input.data_ptr()), \
tile_idx_to_mn_limit.data_ptr<int32_t>(), expanded_idx_to_permuted_idx.data_ptr<int32_t>(), \
permuted_idx_to_expanded_idx.data_ptr<int32_t>(), num_non_exiting_tiles.data_ptr<int32_t>(), \
max_num_permuted_tokens, hidden_size, top_k, tile_tokens_dim, stream); \
} \
} while (0)
if (input.scalar_type() == torch::kHalf)
{
DISPATCH_MOE_OUTPUT_MEMSET(half);
}
else if (input.scalar_type() == torch::kBFloat16)
{
DISPATCH_MOE_OUTPUT_MEMSET(__nv_bfloat16);
}
else
{
TORCH_CHECK(false, "Unsupported input dtype: ", input.scalar_type());
}
#undef DISPATCH_MOE_OUTPUT_MEMSET
}
// Activation
torch::Tensor moe_swiglu(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit,
@ -421,6 +486,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
"moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
"Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");
m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor");
m.def(
"moe_output_memset_inplace(Tensor(a!) input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
"Tensor permuted_idx_to_expanded_idx, Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k, int "
"ep_size, bool enable_alltoall = False) -> ()");
m.def(
"moe_swiglu(Tensor input, Tensor tile_idx_to_mn_limit, Tensor num_non_exiting_tiles, "
"int tile_tokens_dim) -> Tensor");
@ -438,6 +507,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
m.impl("moe_sort", &torch_ext::moe_sort);
m.impl("moe_permute", &torch_ext::moe_permute);
m.impl("moe_unpermute", &torch_ext::moe_unpermute);
m.impl("moe_output_memset_inplace", &torch_ext::moe_output_memset_inplace);
m.impl("moe_swiglu", &torch_ext::moe_swiglu);
m.impl("moe_swiglu_nvfp4_quantize", &torch_ext::moe_swiglu_nvfp4_quantize);
m.impl("moe_gelu", &torch_ext::moe_gelu);

View File

@ -730,7 +730,7 @@ class AutoTuner:
# Log the cache miss. Expect no cache miss in inference.
if not is_cache_hit:
logger.warning_once(
f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
f"[AutoTuner] {custom_op} using the fallback tactic, due to cache miss on input shapes={input_shapes}",
key=(custom_op, "warning_autotuning_cache_miss_fallback"))
return (best_runner, best_tactic)

View File

@ -77,6 +77,13 @@ def inplace_info():
torch.ops.trtllm.logits_bitmask.default: {
1: "logits"
},
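# The integer key is the 1-based position of the mutated argument in the op schema.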
torch.ops.trtllm.moe_output_memset_inplace.default: {
1: "input"
},
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell.default:
{
6: "output"
},
torch.ops.trtllm.pp_recv_tensors.default: {
1: "tensors"
},

View File

@ -149,7 +149,7 @@ class GroupedGemmInputsHelper:
def inputs_pre_hook_finalize_fusion(
self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
num_tokens = self.infer_num_tokens(a.size(0))
num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
@ -184,7 +184,7 @@ class GroupedGemmInputsHelper:
[num_non_exiting_tiles_val],
dtype=num_non_exiting_tiles.dtype,
device=num_non_exiting_tiles.device)
return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales
return a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales
class FusedMoEInputsHelper:
@ -268,8 +268,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
**kwargs,
) -> List[Tuple[int, int]]:
# Early exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103
sm_version = get_sm_version()
if sm_version not in [100, 103]:
if (sm_version := get_sm_version()) not in (100, 103):
logger.debug(
f"CuteDSL: SM version {sm_version} is not supported. "
f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics."
@ -597,8 +596,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
for automatic backend selection with better performance.
"""
# Validate SM version before attempting to use CuteDSL
sm_version = get_sm_version()
if sm_version not in [100, 103]:
if (sm_version := get_sm_version()) not in (100, 103):
raise ValueError(
f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
f"Please use nvfp4_gemm with backend='auto' for automatic backend selection."
@ -660,9 +658,9 @@ if IS_CUTLASS_DSL_AVAILABLE:
self.output_dtype = output_dtype
self.scaling_vector_size = scaling_vector_size
if get_sm_version() != 100:
if (sm_version := get_sm_version()) not in (100, 103):
raise ValueError(
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
)
def unique_id(self):
@ -947,9 +945,9 @@ if IS_CUTLASS_DSL_AVAILABLE:
self.output_dtype = output_dtype
self.scaling_vector_size = scaling_vector_size
if get_sm_version() != 100:
if (sm_version := get_sm_version()) not in (100, 103):
raise ValueError(
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
)
def unique_id(self):
@ -1015,11 +1013,12 @@ if IS_CUTLASS_DSL_AVAILABLE:
helper.map_to_tuning_buckets), ),
constraint_specs=(
ConstraintSpec(2, 0, fp4_scale_infer_shape),
ConstraintSpec(5, 0, helper.infer_shape_max_num_tiles),
ConstraintSpec(5, 0, helper.infer_shape_num_tokens),
ConstraintSpec(6, 0, helper.infer_shape_max_num_tiles),
ConstraintSpec(7, 0, helper.infer_shape_max_num_tiles),
ConstraintSpec(
7, 0, helper.infer_shape_max_num_permuted_tokens),
ConstraintSpec(9, 0, helper.infer_shape_num_tokens)),
8, 0, helper.infer_shape_max_num_permuted_tokens),
ConstraintSpec(10, 0, helper.infer_shape_num_tokens)),
inputs_pre_hook=helper.inputs_pre_hook_finalize_fusion,
use_cuda_graph=True,
)
@ -1027,7 +1026,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
def forward(self, inputs: List[torch.Tensor],
tactic: Optional[tuple]) -> torch.Tensor:
a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
a, b, a_sf, b_sf, alpha, c, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
assert a.dtype == torch.float4_e2m1fn_x2
assert a.dim() == 2
assert b.dtype == torch.float4_e2m1fn_x2
@ -1051,6 +1050,11 @@ if IS_CUTLASS_DSL_AVAILABLE:
assert b_sf.size(2) == scale_k
assert alpha.size(0) == l
assert c.dtype == self.output_dtype
assert c.dim() == 2
num_tokens = c.size(0)
assert c.size(1) == n
num_tiles = m // self.tile_size
assert tile_idx_to_group_idx.dtype == torch.int32
assert tile_idx_to_group_idx.size() == (num_tiles, )
@ -1062,14 +1066,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
assert num_non_exiting_tiles.numel() == 1
assert token_final_scales.dtype == torch.float32
assert token_final_scales.dim() == 2
num_tokens = token_final_scales.size(0)
assert token_final_scales.size(1) == self.top_k
# TODO: Overlap the memset
c = torch.zeros(num_tokens,
n,
dtype=self.output_dtype,
device=a.device)
assert token_final_scales.size() == (num_tokens, self.top_k)
a_ptr = make_ptr(cutlass.Float4E2M1FN,
a.data_ptr(),
@ -1182,6 +1179,51 @@ if IS_CUTLASS_DSL_AVAILABLE:
)
return c
@torch.library.custom_op(
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell",
mutates_args=("output", ),
device_types="cuda")
def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
alpha: torch.Tensor,
output: torch.Tensor,
tile_idx_to_group_idx: torch.Tensor,
tile_idx_to_mn_limit: torch.Tensor,
permuted_idx_to_expanded_idx: torch.Tensor,
num_non_exiting_tiles: torch.Tensor,
token_final_scales: torch.Tensor,
num_experts: int,
top_k: int,
num_local_experts: int,
local_expert_offset: int,
tile_size: int,
output_dtype: torch.dtype,
scaling_vector_size: int = 16,
) -> None:
tuner = AutoTuner.get()
runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner(
num_experts, top_k, num_local_experts, local_expert_offset,
tile_size, output_dtype, scaling_vector_size)
inputs = [
input, weight, input_scale, weight_scale, alpha, output,
tile_idx_to_group_idx, tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx, num_non_exiting_tiles,
token_final_scales
]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell",
[runner],
runner.get_tuning_config(),
inputs,
)
runner(inputs, tactic=best_tactic)
@torch.library.custom_op(
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell",
mutates_args=(),
@ -1205,25 +1247,32 @@ if IS_CUTLASS_DSL_AVAILABLE:
output_dtype: torch.dtype,
scaling_vector_size: int = 16,
) -> torch.Tensor:
tuner = AutoTuner.get()
runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner(
num_experts, top_k, num_local_experts, local_expert_offset,
tile_size, output_dtype, scaling_vector_size)
inputs = [
input, weight, input_scale, weight_scale, alpha,
tile_idx_to_group_idx, tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx, num_non_exiting_tiles,
token_final_scales
]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell",
[runner],
runner.get_tuning_config(),
inputs,
num_tokens = token_final_scales.size(0)
n = weight.size(1)
output = torch.zeros(num_tokens,
n,
dtype=output_dtype,
device=input.device)
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
input=input,
weight=weight,
input_scale=input_scale,
weight_scale=weight_scale,
alpha=alpha,
output=output,
tile_idx_to_group_idx=tile_idx_to_group_idx,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
token_final_scales=token_final_scales,
num_experts=num_experts,
top_k=top_k,
num_local_experts=num_local_experts,
local_expert_offset=local_expert_offset,
tile_size=tile_size,
output_dtype=output_dtype,
scaling_vector_size=scaling_vector_size,
)
output = runner(inputs, tactic=best_tactic)
return output
@torch.library.register_fake(
@ -1275,9 +1324,9 @@ if IS_CUTLASS_DSL_AVAILABLE:
self.tile_size = tile_size
self.scaling_vector_size = scaling_vector_size
if get_sm_version() != 100:
if (sm_version := get_sm_version()) not in (100, 103):
raise ValueError(
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
)
def unique_id(self):

View File

@ -2631,6 +2631,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
):
scale_k = k // scaling_vector_size
interm_size = n // 2
scale_interm_size = interm_size // scaling_vector_size
num_tiles = m // tile_size
a = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout((m, k, 1), order=(1, 0, 2)))
b = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2)))
@ -2652,7 +2653,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
c_sf = cute.make_tensor(
c_sf_ptr,
layout=cute.make_ordered_layout(
(32, 4, interm_size // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5)
(32, 4, m // 128, 4, scale_interm_size // 4, 1), order=(2, 1, 4, 0, 3, 5)
),
)
alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,)))

View File

@ -47,7 +47,9 @@ from .communication import (
NVLinkOneSided,
NVLinkTwoSided,
)
from .fused_moe_cute_dsl import CuteDslFusedMoE
from .fused_moe_cutlass import CutlassFusedMoE
from .fused_moe_deepgemm import DeepGemmFusedMoE
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
@ -56,7 +58,7 @@ class ConfigurableMoE(MoE):
Configurable MoE layer using composition pattern with automatic configuration
This class orchestrates the MoE execution flow by composing:
- moe_backend: Existing FusedMoE implementation (CutlassFusedMoE, WideEPMoE, etc.)
- moe_backend: Existing FusedMoE implementation (CutlassFusedMoE, CuteDslFusedMoE, etc.)
Note: Current FusedMoE implementations are used as backends (transitional).
Future will have dedicated MoEBackend interface.
- Communication: Handles distributed communication (auto-selected)
@ -797,7 +799,7 @@ class ConfigurableMoE(MoE):
"""
Get the current MoE backend implementation
Note: Returns a FusedMoE instance (e.g., CutlassFusedMoE, WideEPMoE)
Note: Returns a FusedMoE instance (e.g., CutlassFusedMoE, CuteDslFusedMoE)
"""
return self._backend
@ -902,27 +904,26 @@ class ConfigurableMoE(MoE):
Returns:
Dict: Backend-specific keyword arguments
"""
backend_name = self.backend.__class__.__name__
kwargs = {}
# Common parameters for Cutlass and DeepGemm
if backend_name in ["CutlassFusedMoE", "DeepGemmFusedMoE"]:
if self.backend.__class__ in (CutlassFusedMoE, DeepGemmFusedMoE, CuteDslFusedMoE):
pass
# Cutlass-specific parameters
if backend_name == "CutlassFusedMoE":
if self.backend.__class__ == CutlassFusedMoE:
pass
# WideEP-specific parameters
elif backend_name == "WideEPMoE":
pass
# CuteDSL-specific parameters
elif self.backend.__class__ == CuteDslFusedMoE:
kwargs["enable_alltoall"] = self.enable_alltoall
# DeepGemm-specific parameters
elif backend_name == "DeepGemmFusedMoE":
elif self.backend.__class__ == DeepGemmFusedMoE:
pass
# TRTLLMGen-specific parameters
elif backend_name == "TRTLLMGenFusedMoE":
elif self.backend.__class__ == TRTLLMGenFusedMoE:
# Determine router_logits based on whether routing has been done
# If backend doesn't support load balancer, routing is done before communication
# In that case, router_logits should be None (routing already done)

View File

@ -20,6 +20,8 @@ from .interface import MoE, MoEWeightLoadingMode
from .moe_load_balancer import get_moe_load_balancer
from .routing import BaseMoeRoutingMethod
ENABLE_CONFIGURABLE_MOE = os.environ.get("ENABLE_CONFIGURABLE_MOE", "0") == "1"
def get_moe_cls(
model_config: ModelConfig,
@ -33,7 +35,16 @@ def get_moe_cls(
elif moe_backend.upper() == "VANILLA":
return VanillaMoE
elif moe_backend.upper() == "CUTEDSL":
return CuteDslFusedMoE
if quant_config is not None and (
quant_config.quant_mode.has_fp8_block_scales()
or quant_config.quant_mode.has_nvfp4()):
return CuteDslFusedMoE
else:
logger.warning(
"CuteDslFusedMoE only supports fp8_block_scales and nvfp4. "
f"Check out details in quant_config: {quant_config}. Using CutlassFusedMoE instead."
)
return CutlassFusedMoE
elif moe_backend.upper() == "DEEPGEMM":
return DeepGemmFusedMoE
elif moe_backend.upper() == "TRTLLM":
@ -48,8 +59,8 @@ def get_moe_cls(
else:
logger.warning(
"TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8. "
f"Check out details in quant_config: {quant_config}"
"Using CutlassFusedMoE instead.")
f"Check out details in quant_config: {quant_config}. Using CutlassFusedMoE instead."
)
return CutlassFusedMoE
elif moe_backend.upper() == "WIDEEP":
return WideEPMoE
@ -129,8 +140,8 @@ def create_moe_backend(
moe_load_balancer = get_moe_load_balancer()
if moe_load_balancer is not None:
assert moe_cls in [
WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE
], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE and TRTLLMGenFusedMoE now."
WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE, CuteDslFusedMoE
], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE and CuteDslFusedMoE now."
if bias:
assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE
@ -229,6 +240,8 @@ def create_moe_backend(
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
init_load_balancer=init_load_balancer,
without_comm=without_comm,
)
elif moe_cls == DeepGemmFusedMoE:
return moe_cls(
@ -331,13 +344,9 @@ def create_moe(
moe_cls = get_moe_cls(model_config, override_quant_config)
# Check if ENABLE_CONFIGURABLE_MOE environment variable is set
enable_configurable_moe = os.environ.get('ENABLE_CONFIGURABLE_MOE',
'0') == '1'
if enable_configurable_moe:
# ConfigurableMoE is only supported for TRTLLMGenFusedMoE backend
if moe_cls == TRTLLMGenFusedMoE:
if ENABLE_CONFIGURABLE_MOE or moe_cls == CuteDslFusedMoE:
# ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends
if moe_cls in (TRTLLMGenFusedMoE, CuteDslFusedMoE):
return ConfigurableMoE(
routing_method=routing_method,
num_experts=num_experts,
@ -358,12 +367,13 @@ def create_moe(
else:
# Check if this is a TRTLLM backend request that fallback to CutlassFusedMoE
requested_backend = model_config.moe_backend.upper()
if requested_backend == "TRTLLM" and moe_cls == CutlassFusedMoE:
if requested_backend in ("TRTLLM",
"CUTEDSL") and moe_cls == CutlassFusedMoE:
# Workaround for test cases where TRTLLM backend fallbacks to CutlassFusedMoE due to quant_config incompatibility
# Log warning and continue with the fallback backend
logger.warning(
f"ENABLE_CONFIGURABLE_MOE is set but TRTLLM backend fallback to {moe_cls.__name__} due to quant_config. "
f"ConfigurableMoE only supports TRTLLMGenFusedMoE backend. "
f"ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends. "
f"Continuing with legacy MoE backend {moe_cls.__name__}.")
else:
# For other incompatible backends, raise error

View File

@ -8,7 +8,7 @@ from tensorrt_llm._utils import is_sm_100f
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div
from ...utils import AuxStreamType, Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .interface import AlltoallMethodType
from .quantization import MoEWeightLoadingMode, NVFP4CuteDslFusedMoEMethod
@ -180,6 +180,8 @@ class CuteDslFusedMoE(CutlassFusedMoE):
VANILLA,
apply_router_weight_on_input: bool = False,
layer_idx: Optional[int] = None,
init_load_balancer: bool = True,
without_comm: bool = False,
):
super().__init__(
@ -194,6 +196,8 @@ class CuteDslFusedMoE(CutlassFusedMoE):
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
init_load_balancer=init_load_balancer,
without_comm=without_comm,
)
def select_alltoall_method_type(self) -> AlltoallMethodType:
@ -206,51 +210,163 @@ class CuteDslFusedMoE(CutlassFusedMoE):
return NVFP4CuteDslFusedMoEMethod()
return super()._get_quant_method()
def forward_chunk_unquantized(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: tuple = (True, True),
) -> torch.Tensor:
assert not self.has_any_quant
return super().forward_chunk(x,
router_logits,
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=repeating_info)
def quantize_input(self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
post_quant_comm: bool = True):
"""Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation.
def forward_chunk_fp8_block_scales(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: tuple = (True, True),
Args:
x: Input tensor to quantize
post_quant_comm:
If True, quantize for the post-quant communication path.
If False, quantize for the non-communication path.
Returns: (x, x_sf) where x_sf is already reshaped to 2D if needed
For quantization methods that produce scaling factors:
- x_sf is reshaped from 1D to 2D: [num_elements] -> [batch_size, ceil_div(hidden_size, scaling_vector_size)]
- The 2D shape is required for proper handling in alltoall/allgather operations
- scaling_vector_size is typically the group size for block-wise quantization
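- For example (illustrative numbers only), with hidden_size = 7168 and scaling_vector_size = 16, x_sf is viewed as [batch_size, 448]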
"""
x_sf = None
if self.has_nvfp4:
if isinstance(x, Fp4QuantizedTensor):
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
x_row = x.shape[0]
x, x_sf = x.fp4_tensor, x.scaling_factor
else:
x_row = x.shape[0]
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size, False,
False)
elif self.has_deepseek_fp8_block_scales:
# FP8 block scales doesn't support permutation of quantized inputs.
# WAR: The quantization is in run_moe_fp8_block_scales.
pass
else:
raise ValueError(
f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}."
)
if x_sf is not None:
x_sf = x_sf.view(x_row, -1)
return x, x_sf
def run_moe_nvfp4(
self,
x: torch.Tensor,
token_selected_experts: torch.Tensor,
token_final_scales: Optional[torch.Tensor],
x_sf: Optional[torch.Tensor] = None,
enable_alltoall: bool = False,
) -> torch.Tensor:
assert self.has_nvfp4
output_dtype = torch.bfloat16
tile_size = 128
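# Pipeline: moe_sort builds the tile/expert routing metadata, moe_permute gathers tokens into a tile-contiguous
# layout, grouped GEMM1 + SwiGLU produces the intermediate activations, and GEMM2 either fuses the finalize
# (scatter-accumulating into the selectively memset output buffer) or is followed by moe_unpermute.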
tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
local_expert_offset=self.slot_start,
local_num_experts=self.expert_size_per_partition,
tile_tokens_dim=tile_size,
)
x, x_sf = torch.ops.trtllm.moe_permute(
input=x.view(torch.float4_e2m1fn_x2),
input_sf=x_sf,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
tile_tokens_dim=tile_size,
top_k=self.routing_method.experts_per_token,
)
x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8),
alpha=self.quant_scales.fc1_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
global_sf=self.fc2_input_scale,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
)
if self.use_fused_finalize:
output = torch.empty((token_final_scales.size(0), self.hidden_size),
dtype=output_dtype,
device=x.device)
torch.ops.trtllm.moe_output_memset_inplace(
input=output,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
tile_tokens_dim=tile_size,
top_k=self.routing_method.experts_per_token,
ep_size=self.mapping.moe_ep_size,
enable_alltoall=enable_alltoall,
)
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc2_weight_block.view(
torch.uint8),
alpha=self.quant_scales.fc2_global,
output=output,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
token_final_scales=token_final_scales,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
output_dtype=output_dtype,
)
x = output
else:
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc2_weight_block.view(
torch.uint8),
alpha=self.quant_scales.fc2_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
output_dtype=output_dtype,
)
x = torch.ops.trtllm.moe_unpermute(
permuted_input=x,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
topk_scales=token_final_scales,
)
return x
def run_moe_fp8_block_scales(
self,
x: torch.Tensor,
token_selected_experts: torch.Tensor,
token_final_scales: Optional[torch.Tensor],
x_sf: Optional[torch.Tensor] = None,
enable_alltoall: bool = False,
) -> torch.Tensor:
assert self.has_deepseek_fp8_block_scales
# apply routing
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
assert token_selected_experts.shape[
1] == self.routing_method.experts_per_token
assert token_selected_experts.shape == token_final_scales.shape
assert token_selected_experts.shape[0] == router_logits.shape[0]
assert token_final_scales.dtype == torch.float32
assert token_selected_experts.dtype == torch.int32
if self.apply_router_weight_on_input:
assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
x = x * token_final_scales.to(x.dtype)
# TODO: remove this once we have correct fusedmoe kernel ready
token_final_scales = None
assert x_sf is None
weight_dtype = self.w3_w1_weight.dtype
(
@ -304,7 +420,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
permuted_row_to_unpermuted_row_tensor,
token_selected_experts,
expert_first_token_offset_tensor,
False, # enable_alltoall
enable_alltoall,
x.shape[0], # num_rows
x.shape[1], # (possibly padded) hidden_size
self.unpadded_hidden_size, # original hidden size
@ -317,140 +433,50 @@ class CuteDslFusedMoE(CutlassFusedMoE):
)
return h4
def forward_chunk_nvfp4(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
repeating_info: tuple = (True, True),
def run_moe(
self,
x: torch.Tensor,
token_selected_experts: torch.Tensor,
token_final_scales: Optional[torch.Tensor],
x_sf: Optional[torch.Tensor] = None,
enable_alltoall: bool = False,
) -> torch.Tensor:
assert self.has_nvfp4
"""
Run MoE computation with CuteDSL backend.
if isinstance(x, Fp4QuantizedTensor):
assert output_dtype is not None
else:
output_dtype = x.dtype
This method encapsulates the core MoE computation logic, handling different
quantization schemes (fp8_block_scales and nvfp4).
# apply routing
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
assert token_selected_experts.shape[
1] == self.routing_method.experts_per_token
assert token_selected_experts.shape == token_final_scales.shape
assert token_selected_experts.shape[0] == router_logits.shape[0]
assert token_final_scales.dtype == torch.float32
assert token_selected_experts.dtype == torch.int32
Args:
# Standard MoE interface parameters:
x: Input hidden states (may be pre-quantized)
token_selected_experts: Expert IDs [num_tokens, top_k]. If EPLB is enabled,
this represents expert slots [num_tokens, top_k] instead.
token_final_scales: Final scaling factors for each token
x_sf: Input scale factors (optional, for certain quantization schemes)
enable_alltoall: Whether alltoall communication is enabled.
run_post_quant_allgather = self.use_dp and self.parallel_size > 1
if run_post_quant_allgather:
if isinstance(x, Fp4QuantizedTensor):
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
x, x_sf = x.fp4_tensor, x.scaling_factor
else:
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size, False,
False)
# note: we use uint8 to store 2 fp4 values
x_row, x_col = x.size(0), x.size(1) * 2
else:
if not isinstance(x, Fp4QuantizedTensor):
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size, False,
False)
if run_post_quant_allgather:
# Original allgather logic
if x_sf is not None:
x_sf = x_sf.view(x_row, ceil_div(x_col,
self.scaling_vector_size))
assert x_sf.dim(
) == 2, "The hidden states scaling factor should be 2D tensor before allgather"
x, x_sf, token_selected_experts, token_final_scales = allgather(
[x, x_sf, token_selected_experts, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
tile_size = 128
tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
local_expert_offset=self.slot_start,
local_num_experts=self.expert_size_per_partition,
tile_tokens_dim=tile_size,
)
x, x_sf = torch.ops.trtllm.moe_permute(
input=x.view(torch.float4_e2m1fn_x2),
input_sf=x_sf,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
tile_tokens_dim=tile_size,
top_k=self.routing_method.experts_per_token,
)
x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8),
alpha=self.quant_scales.fc1_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
global_sf=self.fc2_input_scale,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
)
if self.use_fused_finalize:
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc2_weight_block.view(
torch.uint8),
alpha=self.quant_scales.fc2_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
Returns:
final_hidden_states tensor.
"""
if self.has_nvfp4:
return self.run_moe_nvfp4(
x=x,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
output_dtype=output_dtype,
)
x_sf=x_sf,
enable_alltoall=enable_alltoall)
elif self.has_deepseek_fp8_block_scales:
return self.run_moe_fp8_block_scales(
x=x,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
x_sf=x_sf,
enable_alltoall=enable_alltoall)
else:
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc2_weight_block.view(
torch.uint8),
alpha=self.quant_scales.fc2_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
output_dtype=output_dtype,
raise ValueError(
f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}."
)
x = torch.ops.trtllm.moe_unpermute(
permuted_input=x,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
topk_scales=token_final_scales,
)
return x
def forward_chunk(
self,
@ -461,32 +487,30 @@ class CuteDslFusedMoE(CutlassFusedMoE):
use_dp_padding: Optional[bool] = None,
repeating_info: tuple = (True, True),
) -> torch.Tensor:
if self.has_any_quant:
if self.has_nvfp4:
return self.forward_chunk_nvfp4(
x,
router_logits,
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=repeating_info)
elif self.has_deepseek_fp8_block_scales:
return self.forward_chunk_fp8_block_scales(
x,
router_logits,
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=repeating_info)
else:
raise ValueError(
f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
)
else:
return self.forward_chunk_unquantized(
x,
router_logits,
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=repeating_info)
# Currently, the default path is that ConfigurableMoE calls CuteDslFusedMoE.run_moe.
# This forward_chunk method is a reference implementation of the legacy path.
# Apply routing
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
assert token_selected_experts.shape[
1] == self.routing_method.experts_per_token
assert token_selected_experts.shape == token_final_scales.shape
assert token_selected_experts.shape[0] == router_logits.shape[0]
assert token_final_scales.dtype == torch.float32
assert token_selected_experts.dtype == torch.int32
x, x_sf = self.quantize_input(x)
if self.use_dp and self.parallel_size > 1:
x, x_sf, token_selected_experts, token_final_scales = allgather(
[x, x_sf, token_selected_experts, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
x = self.run_moe(x=x,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
x_sf=x_sf,
enable_alltoall=False)
return x

View File

@ -76,6 +76,7 @@ class CutlassFusedMoE(MoE):
swiglu_beta: Optional[torch.Tensor] = None,
swiglu_limit: Optional[torch.Tensor] = None,
init_load_balancer: bool = True,
without_comm: bool = False,
activation_type: ActivationType = ActivationType.Swiglu,
):
@ -138,49 +139,58 @@ class CutlassFusedMoE(MoE):
self.has_been_profiled = False
self.has_been_profiled_min_latency = False
# TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future.
self.alltoall_method_type = self.select_alltoall_method_type()
logger.info_once(
f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
key="alltoall_method_type")
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
if self.enable_alltoall:
self.use_low_precision_combine = model_config.use_low_precision_moe_combine
# When without_comm=True, skip communication initialization (ConfigurableMoE will handle it)
if not without_comm:
self.alltoall_method_type = self.select_alltoall_method_type()
logger.info_once(
f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
key="alltoall_method_type")
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
if self.enable_alltoall:
self.use_low_precision_combine = model_config.use_low_precision_moe_combine
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
# Calculate required workspace size
ep_size = self.mapping.moe_ep_size
max_num_tokens = model_config.max_num_tokens
hidden_size = self.hidden_size
dtype = self.dtype or torch.float16
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
# Calculate required workspace size
ep_size = self.mapping.moe_ep_size
max_num_tokens = model_config.max_num_tokens
hidden_size = self.hidden_size
dtype = self.dtype or torch.float16
workspace_size = MoeAlltoAll.calculate_required_workspace_size(
ep_size, self.routing_method.experts_per_token,
max_num_tokens, hidden_size, dtype)
workspace_size = MoeAlltoAll.calculate_required_workspace_size(
ep_size, self.routing_method.experts_per_token,
max_num_tokens, hidden_size, dtype)
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_size,
)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
"DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
)
else:
raise NotImplementedError(
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
)
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_size,
)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
"DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
)
else:
raise NotImplementedError(
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
)
else:
# When without_comm=True, set minimal attributes
# Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
self.alltoall_method_type = AlltoallMethodType.NotEnabled
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
self.moe_a2a = None
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input

View File

@ -156,14 +156,14 @@ class TRTLLMGenFusedMoE(MoE):
raise NotImplementedError(
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
)
else:
# When without_comm=True, set minimal attributes
# Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
self.alltoall_method_type = AlltoallMethodType.NotEnabled
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
self.moe_a2a = None
else:
# When without_comm=True, set minimal attributes
# Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
self.alltoall_method_type = AlltoallMethodType.NotEnabled
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
self.moe_a2a = None
self._weights_created = False
if not model_config.skip_create_weights_in_init:

View File

@ -2082,34 +2082,44 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):
def post_load_weights(self, module: torch.nn.Module):
super().post_load_weights(module)
def load_expert_w3_w1_weight(self, module: torch.nn.Module,
w1_weight: torch.Tensor,
w3_weight: torch.Tensor,
dst_w3_w1_weight: torch.Tensor):
super().load_expert_w3_w1_weight(module, w1_weight, w3_weight,
dst_w3_w1_weight)
# Interleave FC1 weight and scales for GEMM1 + SwiGLU fusion.
w3_w1_weight = module.w3_w1_weight.data.view(float4_e2m1x2)
m = w3_w1_weight.size(1)
n = w3_w1_weight.size(2) * 2
# Interleave FC1 weight for GEMM1 + SwiGLU fusion.
w3_w1_weight = dst_w3_w1_weight.cuda().view(float4_e2m1x2)
w3_w1_weight_interleaved = interleave_linear_and_gate(w3_w1_weight,
group_size=64,
dim=1)
dim=0)
w3_w1_weight_interleaved = w3_w1_weight_interleaved.view(
module.w3_w1_weight.data.dtype)
module.w3_w1_weight.data.copy_(w3_w1_weight_interleaved)
dst_w3_w1_weight.dtype)
dst_w3_w1_weight.copy_(w3_w1_weight_interleaved)
w3_w1_weight_scale = module.quant_scales.fc1_weight_block.data.view(
float4_sf_dtype)
def load_expert_w3_w1_weight_scale_nvfp4(
self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
dst_w3_w1_weight_scale: torch.Tensor):
super().load_expert_w3_w1_weight_scale_nvfp4(module, w1_weight_scale,
w3_weight_scale,
dst_w3_w1_weight_scale)
# Interleave FC1 scales for GEMM1 + SwiGLU fusion.
n = module.intermediate_size_per_partition * 2
k = module.hidden_size
w3_w1_weight_scale = dst_w3_w1_weight_scale.cuda().view(float4_sf_dtype)
w3_w1_weight_scale_unswizzled = unswizzle_sf(
w3_w1_weight_scale, m, n).view(-1, m,
n // module.scaling_vector_size)
w3_w1_weight_scale, n, k).view(n, k // module.scaling_vector_size)
w3_w1_weight_scale_unswizzled_interleaved = interleave_linear_and_gate(
w3_w1_weight_scale_unswizzled, group_size=64, dim=1)
w3_w1_weight_scale_unswizzled, group_size=64, dim=0)
w3_w1_weight_scale_interleaved = swizzle_sf(
w3_w1_weight_scale_unswizzled_interleaved, m,
n).view(-1, m, n // module.scaling_vector_size)
w3_w1_weight_scale_unswizzled_interleaved, n,
k).view(n, k // module.scaling_vector_size)
w3_w1_weight_scale_interleaved = w3_w1_weight_scale_interleaved.view(
module.quant_scales.fc1_weight_block.data.dtype)
module.quant_scales.fc1_weight_block.data.copy_(
w3_w1_weight_scale_interleaved)
dst_w3_w1_weight_scale.dtype)
dst_w3_w1_weight_scale.copy_(w3_w1_weight_scale_interleaved)
class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
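The CuteDSL method above re-lays out the fused FC1 weight and its NVFP4 block scales so that the gate and linear halves alternate in 64-row groups, which is what the fused GEMM1 + SwiGLU kernel consumes. The real interleave_linear_and_gate lives elsewhere in the tree; the sketch below only illustrates the assumed block-interleaving semantics and is not the production implementation.
import torch

def interleave_halves(w: torch.Tensor, group_size: int = 64, dim: int = 0) -> torch.Tensor:
    # Split the stacked tensor into its two halves along `dim` and interleave
    # them in `group_size`-sized blocks: a0, b0, a1, b1, ...
    assert w.size(dim) % (2 * group_size) == 0
    half_a, half_b = w.chunk(2, dim=dim)
    a = half_a.unflatten(dim, (-1, group_size))
    b = half_b.unflatten(dim, (-1, group_size))
    return torch.stack((a, b), dim=dim + 1).flatten(dim, dim + 2)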

View File

@ -112,13 +112,13 @@ class GemmRunner(TunableRunner):
mutates_args=())
def get_best_gemm_tactic(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
runners = [GemmRunner()]
tunner = AutoTuner.get()
tuner = AutoTuner.get()
tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
input_idx=0,
dim_idx=0,
gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
map_to_tuning_buckets=next_positive_power_of_2), ), )
runner, tactic = tunner.choose_one(
runner, tactic = tuner.choose_one(
"autotuner_test::get_best_gemm_tactic",
runners,
tuning_config,
@ -175,20 +175,20 @@ def test_autotuner_try_block():
x, w = torch.randn(M, 64), torch.randn(64, 128)
runners = [PartialCrashedRunner()]
tunner = AutoTuner.get()
tuner = AutoTuner.get()
tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
input_idx=0,
dim_idx=0,
gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
map_to_tuning_buckets=next_positive_power_of_2), ), )
with autotune():
runner, tactic = tunner.choose_one("test_autotuner_try_block", runners,
tuning_config, [x, w])
runner, tactic = tuner.choose_one("test_autotuner_try_block", runners,
tuning_config, [x, w])
m = M // 2
while m >= 1:
_, tactic = tunner.choose_one("test_autotuner_try_block", runners,
tuning_config, [torch.randn(m, 64), w])
_, tactic = tuner.choose_one("test_autotuner_try_block", runners,
tuning_config, [torch.randn(m, 64), w])
assert tactic in [
-1, 0
], f"Expect only tactic -1, 0 being chosen, but got tactic {tactic}."

View File

@ -1364,9 +1364,10 @@ def test_fused_moe_nvfp4(dtype, moe_backend):
if dtype == torch.float16:
pytest.skip(
"CUTEDSL NVFP4 MoE backend does not support float16 yet")
if get_sm_version() != 100:
if get_sm_version() not in (100, 103):
pytest.skip(
"CUTEDSL NVFP4 MoE backend is only supported on SM 100 GPUs")
"CUTEDSL NVFP4 MoE backend supports SM 100 (B200) and SM 103 (B300) only"
)
test_all_kernels = True
if get_sm_version() == 120:

View File

@ -227,6 +227,60 @@ def test_moe_unpermute(dtype: str, num_tokens: int, top_k: int, tile_size: int):
torch.testing.assert_close(x, x_ref)
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
def test_moe_output_memset_inplace(
dtype: str, num_tokens: int, top_k: int, ep_size: int, tile_size: int
):
dtype = getattr(torch, dtype)
hidden_size = 4096
num_experts = 256
num_local_experts = num_experts // ep_size
enable_alltoall = True
routing_logits = torch.randn(num_tokens, num_experts, device="cuda")
token_final_scales, token_selected_experts = routing_logits.topk(top_k, dim=-1)
token_selected_experts = token_selected_experts.to(torch.int32)
token_final_scales = token_final_scales.softmax(dim=-1).to(torch.float32)
(
tile_idx_to_group_idx,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
total_num_padded_tokens,
num_non_exiting_tiles,
) = torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=num_experts,
top_k=top_k,
local_expert_offset=0,
local_num_experts=num_local_experts,
tile_tokens_dim=tile_size,
)
x = torch.ones(num_tokens, hidden_size, dtype=dtype, device="cuda")
torch.ops.trtllm.moe_output_memset_inplace(
x,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
num_non_exiting_tiles,
tile_size,
top_k,
ep_size,
enable_alltoall=enable_alltoall,
)
x_ref = torch.zeros_like(x)
if enable_alltoall and ep_size > top_k:
x_ref[(expanded_idx_to_permuted_idx < 0).all(dim=-1)] = 1
torch.testing.assert_close(x, x_ref)
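The inline reference above captures the semantics this test checks for moe_output_memset_inplace in the all-to-all case: when EP ranks outnumber top_k, a row whose every top-k slot maps to a negative permuted index has no local expert and keeps its contents, while every other row is zeroed so the finalize step can accumulate into it. A standalone pure-PyTorch restatement of that reference (names are local to this sketch):
import torch

def memset_reference(x, expanded_idx_to_permuted_idx, top_k, ep_size, enable_alltoall):
    out = torch.zeros_like(x)
    if enable_alltoall and ep_size > top_k:
        # Rows routed entirely to non-local experts are left untouched.
        untouched = (expanded_idx_to_permuted_idx < 0).all(dim=-1)
        out[untouched] = x[untouched]
    return out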
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@ -257,7 +311,10 @@ def test_moe_swiglu(dtype: str, num_tokens: int, top_k: int, tile_size: int):
torch.testing.assert_close(y[:num_permuted_tokens], y_ref[:num_permuted_tokens])
@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.skipif(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@ -332,7 +389,10 @@ def test_moe_gelu(dtype: str, num_tokens: int, top_k: int, tile_size: int):
torch.testing.assert_close(y[:num_permuted_tokens], y_ref[:num_permuted_tokens])
@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.skipif(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@ -425,7 +485,10 @@ def test_nvfp4_grouped_gemm_blackwell(num_tokens: int, top_k: int, ep_size: int,
torch.testing.assert_close(c[:num_valid_permuted_tokens], c_ref[:num_valid_permuted_tokens])
@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.skipif(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@ -523,7 +586,10 @@ def test_nvfp4_grouped_gemm_finalize_blackwell(
assert match_ratio > 0.99
@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.skipif(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])