Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-9372][feat] Enable CuteDSL MoE with Large EP (#9592)
Signed-off-by: Enwei Zhu <21126786+syuoni@users.noreply.github.com>
commit 7cd5a67e25 (parent c2f2add6df)
@ -38,6 +38,7 @@
#include <optional>
#include <sstream>
#include <string>
#include <unordered_map>
#ifndef _WIN32 // Linux
#include <sys/sysinfo.h>
#endif // not WIN32
@ -432,6 +433,21 @@ inline int getMaxSharedMemoryPerBlockOptin()
    return nByteMaxSharedMemoryPerBlockOptin;
}

template <typename T>
inline int getMaxActiveBlocksPerSM(T kernel, int blockSize, size_t dynamicSMemSize)
{
    static std::unordered_map<T, int> cache;
    auto it = cache.find(kernel);
    if (it != cache.end())
    {
        return it->second;
    }
    int numBlocks;
    check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, kernel, blockSize, dynamicSMemSize));
    cache[kernel] = numBlocks;
    return numBlocks;
}
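This helper replaces the hard-coded "8 blocks per SM" heuristic used by several launchers later in this diff with an occupancy query that is cached per kernel pointer. A minimal host-side sketch of the intended usage follows; the kernel name and work count are hypothetical placeholders, and only getMaxActiveBlocksPerSM and getMultiProcessorCount are taken from this codebase:

// Sketch only: exampleKernel and numWorkItems are illustrative placeholders.
// Requires the cudaUtils header that defines getMaxActiveBlocksPerSM / getMultiProcessorCount.
#include <algorithm>
#include <cuda_runtime.h>

__global__ void exampleKernel(float* data, int numWorkItems);

void launchExample(float* data, int numWorkItems, cudaStream_t stream)
{
    constexpr int kThreadsPerBlock = 256;
    auto kernel = &exampleKernel;
    // How many resident blocks of this kernel fit on one SM at this block size (result is cached per kernel).
    int const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, /*dynamicSMemSize=*/0);
    int const smCount = tensorrt_llm::common::getMultiProcessorCount();
    // Saturate the device, but never launch more blocks than there are work items.
    int const blocks = std::min(smCount * maxBlocksPerSM, numWorkItems);
    kernel<<<blocks, kThreadsPerBlock, 0, stream>>>(data, numWorkItems);
}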

template <typename T1, typename T2>
inline size_t divUp(T1 const& a, T2 const& b)
{
@ -67,7 +67,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
|
||||
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
|
||||
@ -110,7 +110,7 @@ __global__ void moePermuteKernel(InputType const* input, InputType* permuted_out
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -141,11 +141,11 @@ void moePermute(InputType const* input, InputType* permuted_output, SFType const
    }
#endif

    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
    int32_t const threads = kThreadsPerBlock;

    auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
    int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
    int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
    int32_t const threads = kThreadsPerBlock;

    cudaLaunchConfig_t config;
    config.gridDim = blocks;
@ -195,7 +195,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
|
||||
int32_t const token_idx = blockIdx.x;
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
|
||||
@ -232,7 +232,7 @@ __global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* o
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -277,6 +277,105 @@ INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
|
||||
#endif
|
||||
#undef INSTANTIATE_MOE_UNPERMUTE
|
||||
|
||||
template <typename InputType, int32_t kThreadsPerBlock>
|
||||
__global__ void moeOutputMemsetKernel(InputType* input, int32_t const* tile_idx_to_mn_limit,
|
||||
int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx,
|
||||
int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
|
||||
{
|
||||
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
|
||||
int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
|
||||
|
||||
InputType rmem[kElemPerCopy];
|
||||
#pragma unroll
|
||||
for (int32_t j = 0; j < kElemPerCopy; j++)
|
||||
{
|
||||
rmem[j] = 0;
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
|
||||
for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
|
||||
{
|
||||
int32_t const tile_idx = permuted_idx / tile_size;
|
||||
if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
|
||||
{
|
||||
continue;
|
||||
}
|
||||
int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
|
||||
int32_t const token_idx = expanded_idx / top_k;
|
||||
int32_t const topk_idx = expanded_idx % top_k;
|
||||
|
||||
bool is_first_in_topk = true;
|
||||
for (int32_t k = 0; k < topk_idx; k++)
|
||||
{
|
||||
if (expanded_idx_to_permuted_idx[token_idx * top_k + k] >= 0)
|
||||
{
|
||||
is_first_in_topk = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!is_first_in_topk)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* dst_ptr = reinterpret_cast<ElemCopyType*>(input) + token_idx * kCopyPerToken;
|
||||
for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
|
||||
{
|
||||
dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
|
||||
}
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename InputType>
|
||||
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
|
||||
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
|
||||
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
int32_t constexpr kThreadsPerBlock = 256;
|
||||
int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
|
||||
TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
|
||||
|
||||
auto kernel = &moeOutputMemsetKernel<InputType, kThreadsPerBlock>;
|
||||
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
|
||||
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
|
||||
int32_t const threads = kThreadsPerBlock;
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
config.dynamicSmemBytes = 0;
|
||||
config.stream = stream;
|
||||
cudaLaunchAttribute attrs[1];
|
||||
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
|
||||
config.numAttrs = 1;
|
||||
config.attrs = attrs;
|
||||
cudaLaunchKernelEx(&config, kernel, input, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx,
|
||||
permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_OUTPUT_MEMSET(InputType) \
|
||||
template void moeOutputMemset<InputType>(InputType * input, int32_t const* tile_idx_to_mn_limit, \
|
||||
int32_t const* expanded_idx_to_permuted_idx, int32_t const* permuted_idx_to_expanded_idx, \
|
||||
int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size, \
|
||||
int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_MOE_OUTPUT_MEMSET(half);
|
||||
#ifdef ENABLE_BF16
|
||||
INSTANTIATE_MOE_OUTPUT_MEMSET(__nv_bfloat16);
|
||||
#endif
|
||||
#undef INSTANTIATE_MOE_OUTPUT_MEMSET
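The memset, permute, unpermute, and activation kernels in this change all share the same Programmatic Dependent Launch (PDL) pattern on Hopper/Blackwell: a grid-dependency wait before the kernel consumes its producer's output, a programmatic-launch-completion trigger once the last access has been issued, and a host-side opt-in through the programmatic stream serialization launch attribute (gated here by getEnvEnablePDL()). A stripped-down, self-contained sketch of that pattern; the kernel body and launcher below are illustrative and not code from this commit:

#include <cuda_runtime.h>

__global__ void consumerKernel(float const* producerOut, float* out, int n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Block until the producer kernel's writes to producerOut are visible.
    cudaGridDependencySynchronize();
#endif
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        out[i] = producerOut[i] * 2.0f;
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Signal that dependent kernels may begin launching.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

void launchConsumer(float const* producerOut, float* out, int n, cudaStream_t stream, bool enablePDL)
{
    cudaLaunchConfig_t config{};
    config.gridDim = (n + 255) / 256;
    config.blockDim = 256;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = enablePDL ? 1 : 0;
    config.attrs = attrs;
    config.numAttrs = 1;
    cudaLaunchKernelEx(&config, consumerKernel, producerOut, out, n);
}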
template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
|
||||
int32_t kThreadsPerBlock>
|
||||
__global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
|
||||
@ -297,7 +396,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
|
||||
ActFn act{};
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.wait;");
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
|
||||
@ -353,7 +452,7 @@ __global__ void moeActivationKernel(InputType const* input, OutputType* output,
|
||||
}
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
asm volatile("griddepcontrol.launch_dependents;");
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -382,10 +481,6 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
|
||||
}
|
||||
#endif
|
||||
|
||||
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
int32_t const blocks = std::min(smCount * 8, max_num_permuted_tokens);
|
||||
int32_t const threads = kThreadsPerBlock;
|
||||
|
||||
auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
|
||||
float const* global_sf, SFType* output_sf,
|
||||
int32_t const* tile_idx_to_mn_limit,
|
||||
@ -424,6 +519,11 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
|
||||
};
|
||||
auto kernel = get_act_kernel(activation_params.activation_type);
|
||||
|
||||
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(kernel, kThreadsPerBlock, 0);
|
||||
int32_t const blocks = std::min(smCount * maxBlocksPerSM, max_num_permuted_tokens);
|
||||
int32_t const threads = kThreadsPerBlock;
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
|
||||
@ -32,6 +32,12 @@ void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t co
|
||||
TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename InputType>
|
||||
void moeOutputMemset(InputType* input, int32_t const* tile_idx_to_mn_limit, int32_t const* expanded_idx_to_permuted_idx,
|
||||
int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,
|
||||
int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename InputType, typename OutputType, typename SFType>
|
||||
void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
|
||||
int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
|
||||
|
||||
@ -1587,11 +1587,6 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
|
||||
int64_t num_padding_tokens = 0;
|
||||
#endif
|
||||
|
||||
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
|
||||
int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens));
|
||||
int64_t const threads = EXPAND_THREADS_PER_BLOCK;
|
||||
|
||||
auto func = [&]()
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
@ -1637,6 +1632,12 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
|
||||
}
|
||||
}();
|
||||
|
||||
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(func, EXPAND_THREADS_PER_BLOCK, 0);
|
||||
int32_t const blocks
|
||||
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(num_rows * k, num_padding_tokens)));
|
||||
int32_t const threads = EXPAND_THREADS_PER_BLOCK;
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
@ -1891,15 +1892,18 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro
|
||||
if (parallelism_config.ep_size > 1 && enable_alltoall)
|
||||
{
|
||||
// If all-to-all comm is enabled, finalizeMoeRouting doesn't need to fill the invalid output tokens with zeros.
|
||||
static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
|
||||
int64_t const blocks = smCount * 8;
|
||||
int64_t const threads = FINALIZE_THREADS_PER_BLOCK;
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
auto func = final_scales
|
||||
? &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::DEFAULT>
|
||||
: &finalizeMoeRoutingNoFillingKernel<OutputType, GemmOutputType, ScaleBiasType, ScaleMode::NO_SCALE>;
|
||||
|
||||
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
int32_t const maxBlocksPerSM
|
||||
= tensorrt_llm::common::getMaxActiveBlocksPerSM(func, FINALIZE_THREADS_PER_BLOCK, 0);
|
||||
int32_t const blocks = std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(num_rows * experts_per_token));
|
||||
int32_t const threads = FINALIZE_THREADS_PER_BLOCK;
|
||||
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
cudaLaunchKernelEx(&config, func, expanded_permuted_rows, reduced_unpermuted_output, bias_ptr, final_scales,
|
||||
unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts,
|
||||
expert_first_token_offset, num_rows, padded_cols, unpadded_cols, experts_per_token, num_experts_per_node,
|
||||
@ -2235,11 +2239,6 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
|
||||
int64_t num_padding_tokens = 0;
|
||||
#endif
|
||||
|
||||
static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
// Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200).
|
||||
int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens));
|
||||
int64_t const threads = ACTIVATION_THREADS_PER_BLOCK;
|
||||
|
||||
auto fn = [&]()
|
||||
{
|
||||
// IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
|
||||
@ -2302,6 +2301,12 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
|
||||
}
|
||||
}();
|
||||
|
||||
static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
|
||||
int32_t const maxBlocksPerSM = tensorrt_llm::common::getMaxActiveBlocksPerSM(fn, ACTIVATION_THREADS_PER_BLOCK, 0);
|
||||
int32_t const blocks
|
||||
= std::min(smCount * maxBlocksPerSM, static_cast<int32_t>(std::max(expanded_num_tokens, num_padding_tokens)));
|
||||
int32_t const threads = ACTIVATION_THREADS_PER_BLOCK;
|
||||
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = blocks;
|
||||
config.blockDim = threads;
|
||||
|
||||
@ -647,7 +647,9 @@ void run(Data& data, void* stream)
    //
    // The upper bound is a strict requirement. The number of blocks should be determined by querying
    // the device properties, or conservatively low.
    static int const numBlocksCoop = tensorrt_llm::common::getMultiProcessorCount();
    static int const smCount = tensorrt_llm::common::getMultiProcessorCount();
    // WAR: Reserve 8 SMs for overlapping kernels.
    int const numBlocksCoop = smCount - 8;

    // Maximum number of tokens supported by the kernel using a cooperative launch.
    int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
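To get a feel for the resulting capacity bound, here is a worked example with assumed numbers; none of these constants come from this commit (the SM count depends on the GPU, numThreadsHist is defined elsewhere in this file, and top-k depends on the model):

// All values below are assumptions for illustration only.
constexpr int smCount = 148;        // hypothetical Blackwell-class SM count
constexpr int numThreadsHist = 256; // assumed; the real constant lives in this file
constexpr int topK = 8;             // e.g. a top-8 routing configuration
constexpr int numBlocksCoop = smCount - 8;                                   // 140 blocks after reserving 8 SMs
constexpr int maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / topK;  // 286,720 tokens
static_assert(maxTokensCoop == 286720, "sanity check on the arithmetic");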
@ -139,6 +139,8 @@ std::tuple<torch::Tensor, torch::optional<torch::Tensor>> moe_permute(torch::Ten
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32.");
|
||||
int64_t const num_tiles = tile_idx_to_mn_limit.size(0);
|
||||
TORCH_CHECK(permuted_idx_to_expanded_idx.dim() == 1, "permuted_idx_to_expanded_idx must be 1D.");
|
||||
TORCH_CHECK(
|
||||
permuted_idx_to_expanded_idx.scalar_type() == torch::kInt32, "permuted_idx_to_expanded_idx must be int32.");
|
||||
int64_t const max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0);
|
||||
TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles,
|
||||
"max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles.");
|
||||
@ -253,6 +255,69 @@ torch::Tensor moe_unpermute(torch::Tensor const& permuted_input, torch::Tensor c
|
||||
return output;
|
||||
}
|
||||
|
||||
void moe_output_memset_inplace(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit,
|
||||
torch::Tensor const& expanded_idx_to_permuted_idx, torch::Tensor const& permuted_idx_to_expanded_idx,
|
||||
torch::Tensor const& num_non_exiting_tiles, int64_t const tile_tokens_dim, int64_t const top_k,
|
||||
int64_t const ep_size, bool const enable_alltoall = false)
|
||||
{
|
||||
TORCH_CHECK(input.dim() == 2, "input must be 2D.");
|
||||
int64_t const num_tokens = input.size(0);
|
||||
int64_t const hidden_size = input.size(1);
|
||||
TORCH_CHECK(expanded_idx_to_permuted_idx.dim() == 2, "expanded_idx_to_permuted_idx must be 2D.");
|
||||
TORCH_CHECK(
|
||||
expanded_idx_to_permuted_idx.scalar_type() == torch::kInt32, "expanded_idx_to_permuted_idx must be int32.");
|
||||
TORCH_CHECK(
|
||||
expanded_idx_to_permuted_idx.size(0) == num_tokens, "expanded_idx_to_permuted_idx.size(0) must be num_tokens.");
|
||||
TORCH_CHECK(expanded_idx_to_permuted_idx.size(1) == top_k, "expanded_idx_to_permuted_idx.size(1) must be top_k.");
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.dim() == 1, "tile_idx_to_mn_limit must be 1D.");
|
||||
TORCH_CHECK(tile_idx_to_mn_limit.scalar_type() == torch::kInt32, "tile_idx_to_mn_limit must be int32.");
|
||||
int64_t const num_tiles = tile_idx_to_mn_limit.size(0);
|
||||
TORCH_CHECK(permuted_idx_to_expanded_idx.dim() == 1, "permuted_idx_to_expanded_idx must be 1D.");
|
||||
TORCH_CHECK(
|
||||
permuted_idx_to_expanded_idx.scalar_type() == torch::kInt32, "permuted_idx_to_expanded_idx must be int32.");
|
||||
int64_t const max_num_permuted_tokens = permuted_idx_to_expanded_idx.size(0);
|
||||
TORCH_CHECK(max_num_permuted_tokens == tile_tokens_dim * num_tiles,
|
||||
"max_num_permuted_tokens must be equal to tile_tokens_dim * num_tiles.");
|
||||
TORCH_CHECK(max_num_permuted_tokens >= num_tokens * top_k,
|
||||
"max_num_permuted_tokens must be greater than or equal to num_tokens * top_k.");
|
||||
|
||||
TORCH_CHECK(num_non_exiting_tiles.numel() == 1, "num_non_exiting_tiles must have 1 element.");
|
||||
TORCH_CHECK(num_non_exiting_tiles.scalar_type() == torch::kInt32, "num_non_exiting_tiles must be int32.");
|
||||
|
||||
auto const& stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
#define DISPATCH_MOE_OUTPUT_MEMSET(InputType) \
|
||||
do \
|
||||
{ \
|
||||
if (!enable_alltoall || ep_size <= top_k) \
|
||||
{ \
|
||||
cudaMemsetAsync(input.data_ptr(), 0x0, sizeof(InputType) * num_tokens * hidden_size, stream); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
tensorrt_llm::kernels::cute_dsl::moeOutputMemset<InputType>(static_cast<InputType*>(input.data_ptr()), \
|
||||
tile_idx_to_mn_limit.data_ptr<int32_t>(), expanded_idx_to_permuted_idx.data_ptr<int32_t>(), \
|
||||
permuted_idx_to_expanded_idx.data_ptr<int32_t>(), num_non_exiting_tiles.data_ptr<int32_t>(), \
|
||||
max_num_permuted_tokens, hidden_size, top_k, tile_tokens_dim, stream); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
if (input.scalar_type() == torch::kHalf)
|
||||
{
|
||||
DISPATCH_MOE_OUTPUT_MEMSET(half);
|
||||
}
|
||||
else if (input.scalar_type() == torch::kBFloat16)
|
||||
{
|
||||
DISPATCH_MOE_OUTPUT_MEMSET(__nv_bfloat16);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK(false, "Unsupported input dtype: ", input.scalar_type());
|
||||
}
|
||||
|
||||
#undef DISPATCH_MOE_OUTPUT_MEMSET
|
||||
}
|
||||
|
||||
// Activation
|
||||
|
||||
torch::Tensor moe_swiglu(torch::Tensor const& input, torch::Tensor const& tile_idx_to_mn_limit,
|
||||
@ -421,6 +486,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
|
||||
"moe_permute(Tensor input, Tensor? input_sf, Tensor tile_idx_to_mn_limit, Tensor permuted_idx_to_expanded_idx, "
|
||||
"Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k) -> (Tensor, Tensor?)");
|
||||
m.def("moe_unpermute(Tensor permuted_input, Tensor expanded_idx_to_permuted_idx, Tensor topk_scales) -> Tensor");
|
||||
m.def(
|
||||
"moe_output_memset_inplace(Tensor(a!) input, Tensor tile_idx_to_mn_limit, Tensor expanded_idx_to_permuted_idx, "
|
||||
"Tensor permuted_idx_to_expanded_idx, Tensor num_non_exiting_tiles, int tile_tokens_dim, int top_k, int "
|
||||
"ep_size, bool enable_alltoall = False) -> ()");
|
||||
m.def(
|
||||
"moe_swiglu(Tensor input, Tensor tile_idx_to_mn_limit, Tensor num_non_exiting_tiles, "
|
||||
"int tile_tokens_dim) -> Tensor");
|
||||
@ -438,6 +507,7 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
|
||||
m.impl("moe_sort", &torch_ext::moe_sort);
|
||||
m.impl("moe_permute", &torch_ext::moe_permute);
|
||||
m.impl("moe_unpermute", &torch_ext::moe_unpermute);
|
||||
m.impl("moe_output_memset_inplace", &torch_ext::moe_output_memset_inplace);
|
||||
m.impl("moe_swiglu", &torch_ext::moe_swiglu);
|
||||
m.impl("moe_swiglu_nvfp4_quantize", &torch_ext::moe_swiglu_nvfp4_quantize);
|
||||
m.impl("moe_gelu", &torch_ext::moe_gelu);
|
||||
|
||||
@ -730,7 +730,7 @@ class AutoTuner:
|
||||
# Log the cache miss. Expect no cache miss in inference.
|
||||
if not is_cache_hit:
|
||||
logger.warning_once(
|
||||
f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
|
||||
f"[AutoTuner] {custom_op} using the fallback tactic, due to cache miss on input shapes={input_shapes}",
|
||||
key=(custom_op, "warning_autotuning_cache_miss_fallback"))
|
||||
|
||||
return (best_runner, best_tactic)
|
||||
|
||||
@ -77,6 +77,13 @@ def inplace_info():
|
||||
torch.ops.trtllm.logits_bitmask.default: {
|
||||
1: "logits"
|
||||
},
|
||||
torch.ops.trtllm.moe_output_memset_inplace.default: {
|
||||
1: "input"
|
||||
},
|
||||
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell.default:
|
||||
{
|
||||
6: "output"
|
||||
},
|
||||
torch.ops.trtllm.pp_recv_tensors.default: {
|
||||
1: "tensors"
|
||||
},
|
||||
|
||||
@ -149,7 +149,7 @@ class GroupedGemmInputsHelper:
|
||||
|
||||
def inputs_pre_hook_finalize_fusion(
|
||||
self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
|
||||
a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
|
||||
a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
|
||||
num_tokens = self.infer_num_tokens(a.size(0))
|
||||
num_tokens_per_expert = self.generate_num_tokens_per_expert(num_tokens)
|
||||
tile_idx_to_group_idx_list = self.generate_tile_idx_to_group_idx(
|
||||
@ -184,7 +184,7 @@ class GroupedGemmInputsHelper:
|
||||
[num_non_exiting_tiles_val],
|
||||
dtype=num_non_exiting_tiles.dtype,
|
||||
device=num_non_exiting_tiles.device)
|
||||
return a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales
|
||||
return a, b, a_sf, b_sf, alpha, output, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales
|
||||
|
||||
|
||||
class FusedMoEInputsHelper:
|
||||
@ -268,8 +268,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
**kwargs,
|
||||
) -> List[Tuple[int, int]]:
|
||||
# Early exit: Check SM version - CuteDSL NVFP4 only supports SM 100 and SM 103
|
||||
sm_version = get_sm_version()
|
||||
if sm_version not in [100, 103]:
|
||||
if (sm_version := get_sm_version()) not in (100, 103):
|
||||
logger.debug(
|
||||
f"CuteDSL: SM version {sm_version} is not supported. "
|
||||
f"CuteDSL NVFP4 only supports SM 100 (B200) and SM 103 (B300). Skipping all tactics."
|
||||
@ -597,8 +596,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
for automatic backend selection with better performance.
|
||||
"""
|
||||
# Validate SM version before attempting to use CuteDSL
|
||||
sm_version = get_sm_version()
|
||||
if sm_version not in [100, 103]:
|
||||
if (sm_version := get_sm_version()) not in (100, 103):
|
||||
raise ValueError(
|
||||
f"CuteDSL NVFP4 backend requires SM 100 (B200) or SM 103 (B300), but got SM {sm_version}. "
|
||||
f"Please use nvfp4_gemm with backend='auto' for automatic backend selection."
|
||||
@ -660,9 +658,9 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
self.output_dtype = output_dtype
|
||||
self.scaling_vector_size = scaling_vector_size
|
||||
|
||||
if get_sm_version() != 100:
|
||||
if (sm_version := get_sm_version()) not in (100, 103):
|
||||
raise ValueError(
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
|
||||
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
|
||||
)
|
||||
|
||||
def unique_id(self):
|
||||
@ -947,9 +945,9 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
self.output_dtype = output_dtype
|
||||
self.scaling_vector_size = scaling_vector_size
|
||||
|
||||
if get_sm_version() != 100:
|
||||
if (sm_version := get_sm_version()) not in (100, 103):
|
||||
raise ValueError(
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
|
||||
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
|
||||
)
|
||||
|
||||
def unique_id(self):
|
||||
@ -1015,11 +1013,12 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
helper.map_to_tuning_buckets), ),
|
||||
constraint_specs=(
|
||||
ConstraintSpec(2, 0, fp4_scale_infer_shape),
|
||||
ConstraintSpec(5, 0, helper.infer_shape_max_num_tiles),
|
||||
ConstraintSpec(5, 0, helper.infer_shape_num_tokens),
|
||||
ConstraintSpec(6, 0, helper.infer_shape_max_num_tiles),
|
||||
ConstraintSpec(7, 0, helper.infer_shape_max_num_tiles),
|
||||
ConstraintSpec(
|
||||
7, 0, helper.infer_shape_max_num_permuted_tokens),
|
||||
ConstraintSpec(9, 0, helper.infer_shape_num_tokens)),
|
||||
8, 0, helper.infer_shape_max_num_permuted_tokens),
|
||||
ConstraintSpec(10, 0, helper.infer_shape_num_tokens)),
|
||||
inputs_pre_hook=helper.inputs_pre_hook_finalize_fusion,
|
||||
use_cuda_graph=True,
|
||||
)
|
||||
@ -1027,7 +1026,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
|
||||
def forward(self, inputs: List[torch.Tensor],
|
||||
tactic: Optional[tuple]) -> torch.Tensor:
|
||||
a, b, a_sf, b_sf, alpha, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
|
||||
a, b, a_sf, b_sf, alpha, c, tile_idx_to_group_idx, tile_idx_to_mn_limit, permuted_idx_to_expanded_idx, num_non_exiting_tiles, token_final_scales = inputs
|
||||
assert a.dtype == torch.float4_e2m1fn_x2
|
||||
assert a.dim() == 2
|
||||
assert b.dtype == torch.float4_e2m1fn_x2
|
||||
@ -1051,6 +1050,11 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
assert b_sf.size(2) == scale_k
|
||||
assert alpha.size(0) == l
|
||||
|
||||
assert c.dtype == self.output_dtype
|
||||
assert c.dim() == 2
|
||||
num_tokens = c.size(0)
|
||||
assert c.size(1) == n
|
||||
|
||||
num_tiles = m // self.tile_size
|
||||
assert tile_idx_to_group_idx.dtype == torch.int32
|
||||
assert tile_idx_to_group_idx.size() == (num_tiles, )
|
||||
@ -1062,14 +1066,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
assert num_non_exiting_tiles.numel() == 1
|
||||
assert token_final_scales.dtype == torch.float32
|
||||
assert token_final_scales.dim() == 2
|
||||
num_tokens = token_final_scales.size(0)
|
||||
assert token_final_scales.size(1) == self.top_k
|
||||
|
||||
# TODO: Overlap the memset
|
||||
c = torch.zeros(num_tokens,
|
||||
n,
|
||||
dtype=self.output_dtype,
|
||||
device=a.device)
|
||||
assert token_final_scales.size() == (num_tokens, self.top_k)
|
||||
|
||||
a_ptr = make_ptr(cutlass.Float4E2M1FN,
|
||||
a.data_ptr(),
|
||||
@ -1182,6 +1179,51 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
)
|
||||
return c
|
||||
|
||||
@torch.library.custom_op(
|
||||
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell",
|
||||
mutates_args=("output", ),
|
||||
device_types="cuda")
|
||||
def cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
alpha: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
tile_idx_to_group_idx: torch.Tensor,
|
||||
tile_idx_to_mn_limit: torch.Tensor,
|
||||
permuted_idx_to_expanded_idx: torch.Tensor,
|
||||
num_non_exiting_tiles: torch.Tensor,
|
||||
token_final_scales: torch.Tensor,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
num_local_experts: int,
|
||||
local_expert_offset: int,
|
||||
tile_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16,
|
||||
) -> None:
|
||||
tuner = AutoTuner.get()
|
||||
|
||||
runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner(
|
||||
num_experts, top_k, num_local_experts, local_expert_offset,
|
||||
tile_size, output_dtype, scaling_vector_size)
|
||||
|
||||
inputs = [
|
||||
input, weight, input_scale, weight_scale, alpha, output,
|
||||
tile_idx_to_group_idx, tile_idx_to_mn_limit,
|
||||
permuted_idx_to_expanded_idx, num_non_exiting_tiles,
|
||||
token_final_scales
|
||||
]
|
||||
|
||||
_, best_tactic = tuner.choose_one(
|
||||
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell",
|
||||
[runner],
|
||||
runner.get_tuning_config(),
|
||||
inputs,
|
||||
)
|
||||
runner(inputs, tactic=best_tactic)
|
||||
|
||||
@torch.library.custom_op(
|
||||
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell",
|
||||
mutates_args=(),
|
||||
@ -1205,25 +1247,32 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
output_dtype: torch.dtype,
|
||||
scaling_vector_size: int = 16,
|
||||
) -> torch.Tensor:
|
||||
tuner = AutoTuner.get()
|
||||
|
||||
runner = Sm100BlockScaledContiguousGroupedGemmFinalizeFusionRunner(
|
||||
num_experts, top_k, num_local_experts, local_expert_offset,
|
||||
tile_size, output_dtype, scaling_vector_size)
|
||||
inputs = [
|
||||
input, weight, input_scale, weight_scale, alpha,
|
||||
tile_idx_to_group_idx, tile_idx_to_mn_limit,
|
||||
permuted_idx_to_expanded_idx, num_non_exiting_tiles,
|
||||
token_final_scales
|
||||
]
|
||||
|
||||
_, best_tactic = tuner.choose_one(
|
||||
"trtllm::cute_dsl_nvfp4_grouped_gemm_finalize_blackwell",
|
||||
[runner],
|
||||
runner.get_tuning_config(),
|
||||
inputs,
|
||||
num_tokens = token_final_scales.size(0)
|
||||
n = weight.size(1)
|
||||
output = torch.zeros(num_tokens,
|
||||
n,
|
||||
dtype=output_dtype,
|
||||
device=input.device)
|
||||
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
|
||||
input=input,
|
||||
weight=weight,
|
||||
input_scale=input_scale,
|
||||
weight_scale=weight_scale,
|
||||
alpha=alpha,
|
||||
output=output,
|
||||
tile_idx_to_group_idx=tile_idx_to_group_idx,
|
||||
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
|
||||
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
token_final_scales=token_final_scales,
|
||||
num_experts=num_experts,
|
||||
top_k=top_k,
|
||||
num_local_experts=num_local_experts,
|
||||
local_expert_offset=local_expert_offset,
|
||||
tile_size=tile_size,
|
||||
output_dtype=output_dtype,
|
||||
scaling_vector_size=scaling_vector_size,
|
||||
)
|
||||
output = runner(inputs, tactic=best_tactic)
|
||||
return output
|
||||
|
||||
@torch.library.register_fake(
|
||||
@ -1275,9 +1324,9 @@ if IS_CUTLASS_DSL_AVAILABLE:
|
||||
self.tile_size = tile_size
|
||||
self.scaling_vector_size = scaling_vector_size
|
||||
|
||||
if get_sm_version() != 100:
|
||||
if (sm_version := get_sm_version()) not in (100, 103):
|
||||
raise ValueError(
|
||||
f"SM version {get_sm_version()} is not supported for {self.__class__.__name__}, it only supports SM 100"
|
||||
f"{self.__class__.kernel_class.__name__} supports SM 100 (B200) and SM 103 (B300) only, but got SM {sm_version}"
|
||||
)
|
||||
|
||||
def unique_id(self):
|
||||
|
||||
@ -2631,6 +2631,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
    ):
        scale_k = k // scaling_vector_size
        interm_size = n // 2
        scale_interm_size = interm_size // scaling_vector_size
        num_tiles = m // tile_size
        a = cute.make_tensor(a_ptr, layout=cute.make_ordered_layout((m, k, 1), order=(1, 0, 2)))
        b = cute.make_tensor(b_ptr, layout=cute.make_ordered_layout((n, k, l), order=(1, 0, 2)))
@ -2652,7 +2653,7 @@ class Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel:
        c_sf = cute.make_tensor(
            c_sf_ptr,
            layout=cute.make_ordered_layout(
                (32, 4, interm_size // 128, 4, scale_k // 4, l), order=(2, 1, 4, 0, 3, 5)
                (32, 4, m // 128, 4, scale_interm_size // 4, 1), order=(2, 1, 4, 0, 3, 5)
            ),
        )
        alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,)))
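The corrected c_sf layout can be sanity-checked by element count, assuming (as the surrounding code suggests) that the fused SwiGLU output is an [m, interm_size] tensor carrying one FP8 scale per scaling_vector_size elements (16 by default) along its last dimension: 32 * 4 * (m // 128) * 4 * (scale_interm_size // 4) * 1 = m * scale_interm_size, i.e. exactly one scale per group of the output. The previous layout was sized by interm_size // 128 row blocks and scale_k // 4 column blocks, which does not equal m * scale_interm_size in general.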
@ -47,7 +47,9 @@ from .communication import (
|
||||
NVLinkOneSided,
|
||||
NVLinkTwoSided,
|
||||
)
|
||||
from .fused_moe_cute_dsl import CuteDslFusedMoE
|
||||
from .fused_moe_cutlass import CutlassFusedMoE
|
||||
from .fused_moe_deepgemm import DeepGemmFusedMoE
|
||||
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
|
||||
|
||||
|
||||
@ -56,7 +58,7 @@ class ConfigurableMoE(MoE):
|
||||
Configurable MoE layer using composition pattern with automatic configuration
|
||||
|
||||
This class orchestrates the MoE execution flow by composing:
|
||||
- moe_backend: Existing FusedMoE implementation (CutlassFusedMoE, WideEPMoE, etc.)
|
||||
- moe_backend: Existing FusedMoE implementation (CutlassFusedMoE, CuteDslFusedMoE, etc.)
|
||||
Note: Current FusedMoE implementations are used as backends (transitional).
|
||||
Future will have dedicated MoEBackend interface.
|
||||
- Communication: Handles distributed communication (auto-selected)
|
||||
@ -797,7 +799,7 @@ class ConfigurableMoE(MoE):
|
||||
"""
|
||||
Get the current MoE backend implementation
|
||||
|
||||
Note: Returns a FusedMoE instance (e.g., CutlassFusedMoE, WideEPMoE)
|
||||
Note: Returns a FusedMoE instance (e.g., CutlassFusedMoE, CuteDslFusedMoE)
|
||||
"""
|
||||
return self._backend
|
||||
|
||||
@ -902,27 +904,26 @@ class ConfigurableMoE(MoE):
|
||||
Returns:
|
||||
Dict: Backend-specific keyword arguments
|
||||
"""
|
||||
backend_name = self.backend.__class__.__name__
|
||||
kwargs = {}
|
||||
|
||||
# Common parameters for Cutlass and DeepGemm
|
||||
if backend_name in ["CutlassFusedMoE", "DeepGemmFusedMoE"]:
|
||||
if self.backend.__class__ in (CutlassFusedMoE, DeepGemmFusedMoE, CuteDslFusedMoE):
|
||||
pass
|
||||
|
||||
# Cutlass-specific parameters
|
||||
if backend_name == "CutlassFusedMoE":
|
||||
if self.backend.__class__ == CutlassFusedMoE:
|
||||
pass
|
||||
|
||||
# WideEP-specific parameters
|
||||
elif backend_name == "WideEPMoE":
|
||||
pass
|
||||
# CuteDSL-specific parameters
|
||||
elif self.backend.__class__ == CuteDslFusedMoE:
|
||||
kwargs["enable_alltoall"] = self.enable_alltoall
|
||||
|
||||
# DeepGemm-specific parameters
|
||||
elif backend_name == "DeepGemmFusedMoE":
|
||||
elif self.backend.__class__ == DeepGemmFusedMoE:
|
||||
pass
|
||||
|
||||
# TRTLLMGen-specific parameters
|
||||
elif backend_name == "TRTLLMGenFusedMoE":
|
||||
elif self.backend.__class__ == TRTLLMGenFusedMoE:
|
||||
# Determine router_logits based on whether routing has been done
|
||||
# If backend doesn't support load balancer, routing is done before communication
|
||||
# In that case, router_logits should be None (routing already done)
|
||||
|
||||
@ -20,6 +20,8 @@ from .interface import MoE, MoEWeightLoadingMode
|
||||
from .moe_load_balancer import get_moe_load_balancer
|
||||
from .routing import BaseMoeRoutingMethod
|
||||
|
||||
ENABLE_CONFIGURABLE_MOE = os.environ.get("ENABLE_CONFIGURABLE_MOE", "0") == "1"
|
||||
|
||||
|
||||
def get_moe_cls(
|
||||
model_config: ModelConfig,
|
||||
@ -33,7 +35,16 @@ def get_moe_cls(
|
||||
elif moe_backend.upper() == "VANILLA":
|
||||
return VanillaMoE
|
||||
elif moe_backend.upper() == "CUTEDSL":
|
||||
return CuteDslFusedMoE
|
||||
if quant_config is not None and (
|
||||
quant_config.quant_mode.has_fp8_block_scales()
|
||||
or quant_config.quant_mode.has_nvfp4()):
|
||||
return CuteDslFusedMoE
|
||||
else:
|
||||
logger.warning(
|
||||
"CuteDslFusedMoE only supports fp8_block_scales and nvfp4. "
|
||||
f"Check out details in quant_config: {quant_config}. Using CutlassFusedMoE instead."
|
||||
)
|
||||
return CutlassFusedMoE
|
||||
elif moe_backend.upper() == "DEEPGEMM":
|
||||
return DeepGemmFusedMoE
|
||||
elif moe_backend.upper() == "TRTLLM":
|
||||
@ -48,8 +59,8 @@ def get_moe_cls(
|
||||
else:
|
||||
logger.warning(
|
||||
"TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8. "
|
||||
f"Check out details in quant_config: {quant_config}"
|
||||
"Using CutlassFusedMoE instead.")
|
||||
f"Check out details in quant_config: {quant_config}. Using CutlassFusedMoE instead."
|
||||
)
|
||||
return CutlassFusedMoE
|
||||
elif moe_backend.upper() == "WIDEEP":
|
||||
return WideEPMoE
|
||||
@ -129,8 +140,8 @@ def create_moe_backend(
|
||||
moe_load_balancer = get_moe_load_balancer()
|
||||
if moe_load_balancer is not None:
|
||||
assert moe_cls in [
|
||||
WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE
|
||||
], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE and TRTLLMGenFusedMoE now."
|
||||
WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE, CuteDslFusedMoE
|
||||
], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE and CuteDslFusedMoE now."
|
||||
|
||||
if bias:
|
||||
assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE
|
||||
@ -229,6 +240,8 @@ def create_moe_backend(
|
||||
weight_loading_mode=weight_loading_mode,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
layer_idx=layer_idx,
|
||||
init_load_balancer=init_load_balancer,
|
||||
without_comm=without_comm,
|
||||
)
|
||||
elif moe_cls == DeepGemmFusedMoE:
|
||||
return moe_cls(
|
||||
@ -331,13 +344,9 @@ def create_moe(
|
||||
|
||||
moe_cls = get_moe_cls(model_config, override_quant_config)
|
||||
|
||||
# Check if ENABLE_CONFIGURABLE_MOE environment variable is set
|
||||
enable_configurable_moe = os.environ.get('ENABLE_CONFIGURABLE_MOE',
|
||||
'0') == '1'
|
||||
|
||||
if enable_configurable_moe:
|
||||
# ConfigurableMoE is only supported for TRTLLMGenFusedMoE backend
|
||||
if moe_cls == TRTLLMGenFusedMoE:
|
||||
if ENABLE_CONFIGURABLE_MOE or moe_cls == CuteDslFusedMoE:
|
||||
# ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends
|
||||
if moe_cls in (TRTLLMGenFusedMoE, CuteDslFusedMoE):
|
||||
return ConfigurableMoE(
|
||||
routing_method=routing_method,
|
||||
num_experts=num_experts,
|
||||
@ -358,12 +367,13 @@ def create_moe(
|
||||
else:
|
||||
# Check if this is a TRTLLM backend request that fallback to CutlassFusedMoE
|
||||
requested_backend = model_config.moe_backend.upper()
|
||||
if requested_backend == "TRTLLM" and moe_cls == CutlassFusedMoE:
|
||||
if requested_backend in ("TRTLLM",
|
||||
"CUTEDSL") and moe_cls == CutlassFusedMoE:
|
||||
# Workaround for test cases where TRTLLM backend fallbacks to CutlassFusedMoE due to quant_config incompatibility
|
||||
# Log warning and continue with the fallback backend
|
||||
logger.warning(
|
||||
f"ENABLE_CONFIGURABLE_MOE is set but TRTLLM backend fallback to {moe_cls.__name__} due to quant_config. "
|
||||
f"ConfigurableMoE only supports TRTLLMGenFusedMoE backend. "
|
||||
f"ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends. "
|
||||
f"Continuing with legacy MoE backend {moe_cls.__name__}.")
|
||||
else:
|
||||
# For other incompatible backends, raise error
|
||||
|
||||
@ -8,7 +8,7 @@ from tensorrt_llm._utils import is_sm_100f
|
||||
|
||||
from ...distributed import allgather
|
||||
from ...model_config import ModelConfig
|
||||
from ...utils import AuxStreamType, Fp4QuantizedTensor, ceil_div
|
||||
from ...utils import AuxStreamType, Fp4QuantizedTensor
|
||||
from .fused_moe_cutlass import CutlassFusedMoE
|
||||
from .interface import AlltoallMethodType
|
||||
from .quantization import MoEWeightLoadingMode, NVFP4CuteDslFusedMoEMethod
|
||||
@ -180,6 +180,8 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
VANILLA,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
layer_idx: Optional[int] = None,
|
||||
init_load_balancer: bool = True,
|
||||
without_comm: bool = False,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
@ -194,6 +196,8 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
weight_loading_mode=weight_loading_mode,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
layer_idx=layer_idx,
|
||||
init_load_balancer=init_load_balancer,
|
||||
without_comm=without_comm,
|
||||
)
|
||||
|
||||
def select_alltoall_method_type(self) -> AlltoallMethodType:
|
||||
@ -206,51 +210,163 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
return NVFP4CuteDslFusedMoEMethod()
|
||||
return super()._get_quant_method()
|
||||
|
||||
def forward_chunk_unquantized(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
repeating_info: tuple = (True, True),
|
||||
) -> torch.Tensor:
|
||||
assert not self.has_any_quant
|
||||
return super().forward_chunk(x,
|
||||
router_logits,
|
||||
output_dtype=output_dtype,
|
||||
all_rank_num_tokens=all_rank_num_tokens,
|
||||
use_dp_padding=use_dp_padding,
|
||||
repeating_info=repeating_info)
|
||||
def quantize_input(self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
post_quant_comm: bool = True):
|
||||
"""Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation.
|
||||
|
||||
def forward_chunk_fp8_block_scales(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
repeating_info: tuple = (True, True),
|
||||
Args:
|
||||
x: Input tensor to quantize
|
||||
post_quant_comm:
|
||||
If True, quantize for post-quant communication path.
|
||||
If False, quantize for non-communication path
|
||||
|
||||
Returns: (x, x_sf) where x_sf is already reshaped to 2D if needed
|
||||
|
||||
For quantization methods that produce scaling factors:
|
||||
- x_sf is reshaped from 1D to 2D: [num_elements] -> [batch_size, ceil_div(hidden_size, scaling_vector_size)]
|
||||
- The 2D shape is required for proper handling in alltoall/allgather operations
|
||||
- scaling_vector_size is typically the group size for block-wise quantization
|
||||
"""
|
||||
x_sf = None
|
||||
if self.has_nvfp4:
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
|
||||
x_row = x.shape[0]
|
||||
x, x_sf = x.fp4_tensor, x.scaling_factor
|
||||
else:
|
||||
x_row = x.shape[0]
|
||||
x, x_sf = torch.ops.trtllm.fp4_quantize(
|
||||
x, self.fc31_input_scale, self.scaling_vector_size, False,
|
||||
False)
|
||||
elif self.has_deepseek_fp8_block_scales:
|
||||
# FP8 block scales doesn't support permutation of quantized inputs.
|
||||
# WAR: The quantization is in run_moe_fp8_block_scales.
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}."
|
||||
)
|
||||
|
||||
if x_sf is not None:
|
||||
x_sf = x_sf.view(x_row, -1)
|
||||
return x, x_sf
|
||||
|
||||
def run_moe_nvfp4(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
token_selected_experts: torch.Tensor,
|
||||
token_final_scales: Optional[torch.Tensor],
|
||||
x_sf: Optional[torch.Tensor] = None,
|
||||
enable_alltoall: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert self.has_nvfp4
|
||||
output_dtype = torch.bfloat16
|
||||
tile_size = 128
|
||||
|
||||
tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort(
|
||||
token_selected_experts=token_selected_experts,
|
||||
token_final_scales=token_final_scales,
|
||||
num_experts=self.num_slots,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
local_expert_offset=self.slot_start,
|
||||
local_num_experts=self.expert_size_per_partition,
|
||||
tile_tokens_dim=tile_size,
|
||||
)
|
||||
x, x_sf = torch.ops.trtllm.moe_permute(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
input_sf=x_sf,
|
||||
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
|
||||
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
tile_tokens_dim=tile_size,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
)
|
||||
x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
|
||||
input_scale=x_sf.view(torch.uint8),
|
||||
weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8),
|
||||
alpha=self.quant_scales.fc1_global,
|
||||
tile_idx_to_group_idx=tile_idx_to_expert_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
global_sf=self.fc2_input_scale,
|
||||
num_experts=self.num_slots,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
num_local_experts=self.expert_size_per_partition,
|
||||
local_expert_offset=self.slot_start,
|
||||
tile_size=tile_size,
|
||||
)
|
||||
if self.use_fused_finalize:
|
||||
output = torch.empty((token_final_scales.size(0), self.hidden_size),
|
||||
dtype=output_dtype,
|
||||
device=x.device)
|
||||
torch.ops.trtllm.moe_output_memset_inplace(
|
||||
input=output,
|
||||
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
|
||||
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
|
||||
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
tile_tokens_dim=tile_size,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
ep_size=self.mapping.moe_ep_size,
|
||||
enable_alltoall=enable_alltoall,
|
||||
)
|
||||
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
|
||||
input_scale=x_sf.view(torch.uint8),
|
||||
weight_scale=self.quant_scales.fc2_weight_block.view(
|
||||
torch.uint8),
|
||||
alpha=self.quant_scales.fc2_global,
|
||||
output=output,
|
||||
tile_idx_to_group_idx=tile_idx_to_expert_idx,
|
||||
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
|
||||
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
token_final_scales=token_final_scales,
|
||||
num_experts=self.num_slots,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
num_local_experts=self.expert_size_per_partition,
|
||||
local_expert_offset=self.slot_start,
|
||||
tile_size=tile_size,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
x = output
|
||||
else:
|
||||
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
|
||||
input=x.view(torch.float4_e2m1fn_x2),
|
||||
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
|
||||
input_scale=x_sf.view(torch.uint8),
|
||||
weight_scale=self.quant_scales.fc2_weight_block.view(
|
||||
torch.uint8),
|
||||
alpha=self.quant_scales.fc2_global,
|
||||
tile_idx_to_group_idx=tile_idx_to_expert_idx,
|
||||
num_non_exiting_tiles=num_non_exiting_tiles,
|
||||
num_experts=self.num_slots,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
num_local_experts=self.expert_size_per_partition,
|
||||
local_expert_offset=self.slot_start,
|
||||
tile_size=tile_size,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
x = torch.ops.trtllm.moe_unpermute(
|
||||
permuted_input=x,
|
||||
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
|
||||
topk_scales=token_final_scales,
|
||||
)
|
||||
return x
|
||||
|
||||
def run_moe_fp8_block_scales(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
token_selected_experts: torch.Tensor,
|
||||
token_final_scales: Optional[torch.Tensor],
|
||||
x_sf: Optional[torch.Tensor] = None,
|
||||
enable_alltoall: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert self.has_deepseek_fp8_block_scales
|
||||
|
||||
# apply routing
|
||||
token_selected_experts, token_final_scales = self.routing_method.apply(
|
||||
router_logits)
|
||||
assert token_selected_experts.shape[
|
||||
1] == self.routing_method.experts_per_token
|
||||
assert token_selected_experts.shape == token_final_scales.shape
|
||||
assert token_selected_experts.shape[0] == router_logits.shape[0]
|
||||
assert token_final_scales.dtype == torch.float32
|
||||
assert token_selected_experts.dtype == torch.int32
|
||||
|
||||
if self.apply_router_weight_on_input:
|
||||
assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
|
||||
assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
|
||||
x = x * token_final_scales.to(x.dtype)
|
||||
# TODO: remove this once we have correct fusedmoe kernel ready
|
||||
token_final_scales = None
|
||||
|
||||
assert x_sf is None
|
||||
weight_dtype = self.w3_w1_weight.dtype
|
||||
|
||||
(
|
||||
@ -304,7 +420,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
permuted_row_to_unpermuted_row_tensor,
|
||||
token_selected_experts,
|
||||
expert_first_token_offset_tensor,
|
||||
False, # enable_alltoall
|
||||
enable_alltoall,
|
||||
x.shape[0], # num_rows
|
||||
x.shape[1], # (possibly padded) hidden_size
|
||||
self.unpadded_hidden_size, # original hidden size
|
||||
@ -317,140 +433,50 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
)
|
||||
return h4
|
||||
|
||||
def forward_chunk_nvfp4(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
repeating_info: tuple = (True, True),
|
||||
def run_moe(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
token_selected_experts: torch.Tensor,
|
||||
token_final_scales: Optional[torch.Tensor],
|
||||
x_sf: Optional[torch.Tensor] = None,
|
||||
enable_alltoall: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert self.has_nvfp4
|
||||
"""
|
||||
Run MoE computation with CuteDSL backend.
|
||||
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert output_dtype is not None
|
||||
else:
|
||||
output_dtype = x.dtype
|
||||
This method encapsulates the core MoE computation logic, handling different
|
||||
quantization schemes (fp8_block_scales and nvfp4).
|
||||
|
||||
# apply routing
|
||||
token_selected_experts, token_final_scales = self.routing_method.apply(
|
||||
router_logits)
|
||||
assert token_selected_experts.shape[
|
||||
1] == self.routing_method.experts_per_token
|
||||
assert token_selected_experts.shape == token_final_scales.shape
|
||||
assert token_selected_experts.shape[0] == router_logits.shape[0]
|
||||
assert token_final_scales.dtype == torch.float32
|
||||
assert token_selected_experts.dtype == torch.int32
|
||||
Args:
|
||||
# Standard MoE interface parameters:
|
||||
x: Input hidden states (may be pre-quantized)
|
||||
token_selected_experts: Expert IDs [num_tokens, top_k]. If EPLB is enabled,
|
||||
this represents expert slots [num_tokens, top_k] instead.
|
||||
token_final_scales: Final scaling factors for each token
|
||||
x_sf: Input scale factors (optional, for certain quantization schemes)
|
||||
enable_alltoall: Whether alltoall communication is enabled.
|
||||
|
||||
run_post_quant_allgather = self.use_dp and self.parallel_size > 1
|
||||
if run_post_quant_allgather:
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
|
||||
x, x_sf = x.fp4_tensor, x.scaling_factor
|
||||
else:
|
||||
x, x_sf = torch.ops.trtllm.fp4_quantize(
|
||||
x, self.fc31_input_scale, self.scaling_vector_size, False,
|
||||
False)
|
||||
# note: we use uint8 to store 2 fp4 values
|
||||
x_row, x_col = x.size(0), x.size(1) * 2
|
||||
else:
|
||||
if not isinstance(x, Fp4QuantizedTensor):
|
||||
x, x_sf = torch.ops.trtllm.fp4_quantize(
|
||||
x, self.fc31_input_scale, self.scaling_vector_size, False,
|
||||
False)
|
||||
|
||||
if run_post_quant_allgather:
|
||||
# Original allgather logic
|
||||
if x_sf is not None:
|
||||
x_sf = x_sf.view(x_row, ceil_div(x_col,
|
||||
self.scaling_vector_size))
|
||||
assert x_sf.dim(
|
||||
) == 2, "The hidden states scaling factor should be 2D tensor before allgather"
|
||||
|
||||
x, x_sf, token_selected_experts, token_final_scales = allgather(
|
||||
[x, x_sf, token_selected_experts, token_final_scales],
|
||||
self.mapping,
|
||||
dim=0,
|
||||
sizes=None if use_dp_padding else all_rank_num_tokens)
|
||||
|
||||
tile_size = 128
|
||||
tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles = torch.ops.trtllm.moe_sort(
|
||||
token_selected_experts=token_selected_experts,
|
||||
token_final_scales=token_final_scales,
|
||||
num_experts=self.num_slots,
|
||||
top_k=self.routing_method.experts_per_token,
|
||||
local_expert_offset=self.slot_start,
|
||||
local_num_experts=self.expert_size_per_partition,
|
||||
tile_tokens_dim=tile_size,
|
||||
)
x, x_sf = torch.ops.trtllm.moe_permute(
input=x.view(torch.float4_e2m1fn_x2),
input_sf=x_sf,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
tile_tokens_dim=tile_size,
top_k=self.routing_method.experts_per_token,
)
x, x_sf = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w3_w1_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc1_weight_block.view(torch.uint8),
alpha=self.quant_scales.fc1_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
global_sf=self.fc2_input_scale,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
)
if self.use_fused_finalize:
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc2_weight_block.view(
torch.uint8),
alpha=self.quant_scales.fc2_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
tile_idx_to_mn_limit=tile_idx_to_mn_limit,
permuted_idx_to_expanded_idx=permuted_idx_to_expanded_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
Returns:
final_hidden_states tensor.
"""
if self.has_nvfp4:
return self.run_moe_nvfp4(
x=x,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
output_dtype=output_dtype,
)
x_sf=x_sf,
enable_alltoall=enable_alltoall)
elif self.has_deepseek_fp8_block_scales:
return self.run_moe_fp8_block_scales(
x=x,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
x_sf=x_sf,
enable_alltoall=enable_alltoall)
else:
x = torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_blackwell(
input=x.view(torch.float4_e2m1fn_x2),
weight=self.w2_weight.view(torch.float4_e2m1fn_x2),
input_scale=x_sf.view(torch.uint8),
weight_scale=self.quant_scales.fc2_weight_block.view(
torch.uint8),
alpha=self.quant_scales.fc2_global,
tile_idx_to_group_idx=tile_idx_to_expert_idx,
num_non_exiting_tiles=num_non_exiting_tiles,
num_experts=self.num_slots,
top_k=self.routing_method.experts_per_token,
num_local_experts=self.expert_size_per_partition,
local_expert_offset=self.slot_start,
tile_size=tile_size,
output_dtype=output_dtype,
raise ValueError(
f"{self.__class__.__name__} doesn't support quantization mode {self.quant_config.quant_mode}."
)
x = torch.ops.trtllm.moe_unpermute(
permuted_input=x,
expanded_idx_to_permuted_idx=expanded_idx_to_permuted_idx,
topk_scales=token_final_scales,
)
return x
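# Informal summary of the NVFP4 CuteDSL path above (an outline of the calls in this
# hunk, not an additional API):
#   1. moe_sort: build tile-based routing metadata (tile -> expert, row limits, index maps)
#   2. moe_permute: gather FP4 activations and their scale factors into per-expert tiles
#   3. cute_dsl_nvfp4_grouped_gemm_swiglu_blackwell: grouped FC1 GEMM fused with SwiGLU,
#      producing FP4 output plus scale factors (quantized against fc2_input_scale)
#   4. either cute_dsl_nvfp4_grouped_gemm_finalize_blackwell (FC2 with fused finalize) or
#      cute_dsl_nvfp4_grouped_gemm_blackwell followed by moe_unpermute, which scatters the
#      permuted rows back to token order and applies the top-k scales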
def forward_chunk(
self,
@ -461,32 +487,30 @@ class CuteDslFusedMoE(CutlassFusedMoE):
use_dp_padding: Optional[bool] = None,
repeating_info: tuple = (True, True),
) -> torch.Tensor:
if self.has_any_quant:
if self.has_nvfp4:
return self.forward_chunk_nvfp4(
x,
router_logits,
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=repeating_info)
elif self.has_deepseek_fp8_block_scales:
return self.forward_chunk_fp8_block_scales(
x,
router_logits,
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=repeating_info)
else:
raise ValueError(
f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
)
else:
return self.forward_chunk_unquantized(
x,
router_logits,
output_dtype=output_dtype,
all_rank_num_tokens=all_rank_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=repeating_info)
# Currently, the default path is that ConfigurableMoE calls CuteDslFusedMoE.run_moe.
# This forward_chunk method is a reference implementation of the legacy path.
# Apply routing
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
assert token_selected_experts.shape[
1] == self.routing_method.experts_per_token
assert token_selected_experts.shape == token_final_scales.shape
assert token_selected_experts.shape[0] == router_logits.shape[0]
assert token_final_scales.dtype == torch.float32
assert token_selected_experts.dtype == torch.int32

x, x_sf = self.quantize_input(x)

if self.use_dp and self.parallel_size > 1:
x, x_sf, token_selected_experts, token_final_scales = allgather(
[x, x_sf, token_selected_experts, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)

x = self.run_moe(x=x,
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
x_sf=x_sf,
enable_alltoall=False)
return x
@ -76,6 +76,7 @@ class CutlassFusedMoE(MoE):
swiglu_beta: Optional[torch.Tensor] = None,
swiglu_limit: Optional[torch.Tensor] = None,
init_load_balancer: bool = True,
without_comm: bool = False,
activation_type: ActivationType = ActivationType.Swiglu,
):

@ -138,49 +139,58 @@ class CutlassFusedMoE(MoE):
self.has_been_profiled = False
self.has_been_profiled_min_latency = False

# TODO: AlltoAll code is largely duplicated with WideEPMoE. Consider refactor and reuse in the future.
self.alltoall_method_type = self.select_alltoall_method_type()
logger.info_once(
f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
key="alltoall_method_type")
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
if self.enable_alltoall:
self.use_low_precision_combine = model_config.use_low_precision_moe_combine
# When without_comm=True, skip communication initialization (ConfigurableMoE will handle it)
if not without_comm:
self.alltoall_method_type = self.select_alltoall_method_type()
logger.info_once(
f"{self.__class__.__name__} selects alltoall_method_type {self.alltoall_method_type!r}",
key="alltoall_method_type")
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
if self.enable_alltoall:
self.use_low_precision_combine = model_config.use_low_precision_moe_combine

if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
# Calculate required workspace size
ep_size = self.mapping.moe_ep_size
max_num_tokens = model_config.max_num_tokens
hidden_size = self.hidden_size
dtype = self.dtype or torch.float16
if self.alltoall_method_type == AlltoallMethodType.NVLinkTwoSided:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)
elif self.alltoall_method_type == AlltoallMethodType.NVLinkOneSided:
# Calculate required workspace size
ep_size = self.mapping.moe_ep_size
max_num_tokens = model_config.max_num_tokens
hidden_size = self.hidden_size
dtype = self.dtype or torch.float16

workspace_size = MoeAlltoAll.calculate_required_workspace_size(
ep_size, self.routing_method.experts_per_token,
max_num_tokens, hidden_size, dtype)
workspace_size = MoeAlltoAll.calculate_required_workspace_size(
ep_size, self.routing_method.experts_per_token,
max_num_tokens, hidden_size, dtype)

self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_size,
)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
"DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
)
else:
raise NotImplementedError(
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
)
self.moe_a2a = MoeAlltoAll(
mapping=self.mapping,
max_num_tokens=model_config.max_num_tokens,
top_k=self.routing_method.experts_per_token,
num_experts=self.num_slots,
workspace_size_per_rank=workspace_size,
)
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
raise NotImplementedError(
"DeepEP and DeepEPLowLatency are not supported for CutlassFusedMoE yet"
)
else:
raise NotImplementedError(
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
)
else:
# When without_comm=True, set minimal attributes
# Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
self.alltoall_method_type = AlltoallMethodType.NotEnabled
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
self.moe_a2a = None
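# The without_comm contract assumed by this branch: when a parent wrapper such as
# ConfigurableMoE owns communication, the module skips selecting an alltoall method,
# skips allocating MNNVL / MoeAlltoAll workspaces, and only records neutral defaults
# (NotEnabled, None workspaces, moe_a2a = None) so later code can branch on them safely.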
# If True, the router weight will be multiplied on the input rather than at the end of FC2
self.apply_router_weight_on_input = apply_router_weight_on_input

@ -156,14 +156,14 @@ class TRTLLMGenFusedMoE(MoE):
raise NotImplementedError(
f"Unsupported alltoall method type: {self.alltoall_method_type!r}"
)
else:
# When without_comm=True, set minimal attributes
# Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
self.alltoall_method_type = AlltoallMethodType.NotEnabled
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
self.moe_a2a = None
else:
# When without_comm=True, set minimal attributes
# Communication will be handled by parent wrapper (e.g., ConfigurableMoE)
self.alltoall_method_type = AlltoallMethodType.NotEnabled
self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
self.use_low_precision_combine = False
self.moe_a2a = None

self._weights_created = False
if not model_config.skip_create_weights_in_init:

@ -2082,34 +2082,44 @@ class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):

class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):

def post_load_weights(self, module: torch.nn.Module):
super().post_load_weights(module)
def load_expert_w3_w1_weight(self, module: torch.nn.Module,
w1_weight: torch.Tensor,
w3_weight: torch.Tensor,
dst_w3_w1_weight: torch.Tensor):
super().load_expert_w3_w1_weight(module, w1_weight, w3_weight,
dst_w3_w1_weight)

# Interleave FC1 weight and scales for GEMM1 + SwiGLU fusion.
w3_w1_weight = module.w3_w1_weight.data.view(float4_e2m1x2)
m = w3_w1_weight.size(1)
n = w3_w1_weight.size(2) * 2
# Interleave FC1 weight for GEMM1 + SwiGLU fusion.
w3_w1_weight = dst_w3_w1_weight.cuda().view(float4_e2m1x2)
w3_w1_weight_interleaved = interleave_linear_and_gate(w3_w1_weight,
group_size=64,
dim=1)
dim=0)
w3_w1_weight_interleaved = w3_w1_weight_interleaved.view(
module.w3_w1_weight.data.dtype)
module.w3_w1_weight.data.copy_(w3_w1_weight_interleaved)
dst_w3_w1_weight.dtype)
dst_w3_w1_weight.copy_(w3_w1_weight_interleaved)

w3_w1_weight_scale = module.quant_scales.fc1_weight_block.data.view(
float4_sf_dtype)
def load_expert_w3_w1_weight_scale_nvfp4(
self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
w3_weight_scale: torch.Tensor,
dst_w3_w1_weight_scale: torch.Tensor):
super().load_expert_w3_w1_weight_scale_nvfp4(module, w1_weight_scale,
w3_weight_scale,
dst_w3_w1_weight_scale)

# Interleave FC1 scales for GEMM1 + SwiGLU fusion.
n = module.intermediate_size_per_partition * 2
k = module.hidden_size
w3_w1_weight_scale = dst_w3_w1_weight_scale.cuda().view(float4_sf_dtype)
w3_w1_weight_scale_unswizzled = unswizzle_sf(
w3_w1_weight_scale, m, n).view(-1, m,
n // module.scaling_vector_size)
w3_w1_weight_scale, n, k).view(n, k // module.scaling_vector_size)
w3_w1_weight_scale_unswizzled_interleaved = interleave_linear_and_gate(
w3_w1_weight_scale_unswizzled, group_size=64, dim=1)
w3_w1_weight_scale_unswizzled, group_size=64, dim=0)
w3_w1_weight_scale_interleaved = swizzle_sf(
w3_w1_weight_scale_unswizzled_interleaved, m,
n).view(-1, m, n // module.scaling_vector_size)
w3_w1_weight_scale_unswizzled_interleaved, n,
k).view(n, k // module.scaling_vector_size)
w3_w1_weight_scale_interleaved = w3_w1_weight_scale_interleaved.view(
module.quant_scales.fc1_weight_block.data.dtype)
module.quant_scales.fc1_weight_block.data.copy_(
w3_w1_weight_scale_interleaved)
dst_w3_w1_weight_scale.dtype)
dst_w3_w1_weight_scale.copy_(w3_w1_weight_scale_interleaved)
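To make the interleaving above concrete, the following is a small reference sketch of what interleave_linear_and_gate with group_size=64 along dim 0 is assumed to do here; the grouping semantics, and which half holds the gate versus the linear projection, are assumptions rather than facts taken from the kernel sources.

import torch

def interleave_linear_and_gate_reference(w: torch.Tensor, group_size: int = 64) -> torch.Tensor:
    # Assumed layout: the leading dim stacks the two FC1 projections (w3 and w1) back to
    # back; groups of `group_size` rows from each half are interleaved so the fused
    # GEMM1 + SwiGLU kernel can read matching gate/linear blocks contiguously.
    n = w.size(0)
    assert n % (2 * group_size) == 0
    first, second = w.chunk(2, dim=0)
    a = first.reshape(n // (2 * group_size), group_size, *w.shape[1:])
    b = second.reshape(n // (2 * group_size), group_size, *w.shape[1:])
    # Alternate group-by-group: first_g0, second_g0, first_g1, second_g1, ...
    return torch.stack((a, b), dim=1).reshape(w.shape)

Under this reading, unswizzling the FC1 scale factors, interleaving them with the same group size, and re-swizzling (as done above) keeps the scale blocks aligned with the interleaved weight rows.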
class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):

@ -112,13 +112,13 @@ class GemmRunner(TunableRunner):
mutates_args=())
def get_best_gemm_tactic(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
runners = [GemmRunner()]
tunner = AutoTuner.get()
tuner = AutoTuner.get()
tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
input_idx=0,
dim_idx=0,
gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
map_to_tuning_buckets=next_positive_power_of_2), ), )
runner, tactic = tunner.choose_one(
runner, tactic = tuner.choose_one(
"autotuner_test::get_best_gemm_tactic",
runners,
tuning_config,
@ -175,20 +175,20 @@ def test_autotuner_try_block():

x, w = torch.randn(M, 64), torch.randn(64, 128)
runners = [PartialCrashedRunner()]
tunner = AutoTuner.get()
tuner = AutoTuner.get()
tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
input_idx=0,
dim_idx=0,
gen_tuning_buckets=get_power_of_2_num_tokens_buckets,
map_to_tuning_buckets=next_positive_power_of_2), ), )
with autotune():
runner, tactic = tunner.choose_one("test_autotuner_try_block", runners,
tuning_config, [x, w])
runner, tactic = tuner.choose_one("test_autotuner_try_block", runners,
tuning_config, [x, w])

m = M // 2
while m >= 1:
_, tactic = tunner.choose_one("test_autotuner_try_block", runners,
tuning_config, [torch.randn(m, 64), w])
_, tactic = tuner.choose_one("test_autotuner_try_block", runners,
tuning_config, [torch.randn(m, 64), w])
assert tactic in [
-1, 0
], f"Expect only tactic -1, 0 being chosen, but got tactic {tactic}."

@ -1364,9 +1364,10 @@ def test_fused_moe_nvfp4(dtype, moe_backend):
if dtype == torch.float16:
pytest.skip(
"CUTEDSL NVFP4 MoE backend does not support float16 yet")
if get_sm_version() != 100:
if get_sm_version() not in (100, 103):
pytest.skip(
"CUTEDSL NVFP4 MoE backend is only supported on SM 100 GPUs")
"CUTEDSL NVFP4 MoE backend supports SM 100 (B200) and SM 103 (B300) only"
)

test_all_kernels = True
if get_sm_version() == 120:

@ -227,6 +227,60 @@ def test_moe_unpermute(dtype: str, num_tokens: int, top_k: int, tile_size: int):
torch.testing.assert_close(x, x_ref)


@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@pytest.mark.parametrize("dtype", ["bfloat16", "float16"])
def test_moe_output_memset_inplace(
dtype: str, num_tokens: int, top_k: int, ep_size: int, tile_size: int
):
dtype = getattr(torch, dtype)
hidden_size = 4096
num_experts = 256
num_local_experts = num_experts // ep_size
enable_alltoall = True

routing_logits = torch.randn(num_tokens, num_experts, device="cuda")
token_final_scales, token_selected_experts = routing_logits.topk(top_k, dim=-1)
token_selected_experts = token_selected_experts.to(torch.int32)
token_final_scales = token_final_scales.softmax(dim=-1).to(torch.float32)

(
tile_idx_to_group_idx,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
total_num_padded_tokens,
num_non_exiting_tiles,
) = torch.ops.trtllm.moe_sort(
token_selected_experts=token_selected_experts,
token_final_scales=token_final_scales,
num_experts=num_experts,
top_k=top_k,
local_expert_offset=0,
local_num_experts=num_local_experts,
tile_tokens_dim=tile_size,
)

x = torch.ones(num_tokens, hidden_size, dtype=dtype, device="cuda")
torch.ops.trtllm.moe_output_memset_inplace(
x,
tile_idx_to_mn_limit,
expanded_idx_to_permuted_idx,
permuted_idx_to_expanded_idx,
num_non_exiting_tiles,
tile_size,
top_k,
ep_size,
enable_alltoall=enable_alltoall,
)
x_ref = torch.zeros_like(x)
if enable_alltoall and ep_size > top_k:
x_ref[(expanded_idx_to_permuted_idx < 0).all(dim=-1)] = 1
torch.testing.assert_close(x, x_ref)
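For readers of the test above, a reference formulation of the behavior it asserts (derived from the assertions, not from the kernel implementation) might look like this:

import torch

def moe_output_memset_reference(x: torch.Tensor,
                                expanded_idx_to_permuted_idx: torch.Tensor,
                                enable_alltoall: bool, ep_size: int, top_k: int) -> torch.Tensor:
    # Rows whose token hits no local expert (all expanded indices negative) are left
    # untouched when alltoall is enabled with ep_size > top_k; every other row is
    # zeroed so the finalize GEMM can safely accumulate into it.
    if enable_alltoall and ep_size > top_k:
        needs_zero = ~(expanded_idx_to_permuted_idx < 0).all(dim=-1)
        x[needs_zero] = 0
    else:
        x.zero_()
    return x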
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@ -257,7 +311,10 @@ def test_moe_swiglu(dtype: str, num_tokens: int, top_k: int, tile_size: int):
torch.testing.assert_close(y[:num_permuted_tokens], y_ref[:num_permuted_tokens])


@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.skipif(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128, 256])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@pytest.mark.parametrize("num_tokens", [128, 515, 1024])
@ -332,7 +389,10 @@ def test_moe_gelu(dtype: str, num_tokens: int, top_k: int, tile_size: int):
torch.testing.assert_close(y[:num_permuted_tokens], y_ref[:num_permuted_tokens])


@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.skipif(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@ -425,7 +485,10 @@ def test_nvfp4_grouped_gemm_blackwell(num_tokens: int, top_k: int, ep_size: int,
torch.testing.assert_close(c[:num_valid_permuted_tokens], c_ref[:num_valid_permuted_tokens])


@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.skipif(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])
@ -523,7 +586,10 @@ def test_nvfp4_grouped_gemm_finalize_blackwell(
assert match_ratio > 0.99


@pytest.mark.skipif(get_sm_version() != 100, reason="This test is only supported on SM 100 GPUs")
@pytest.mark.skipif(
get_sm_version() not in (100, 103),
reason="This test is only supported on SM 100 and SM 103 GPUs",
)
@pytest.mark.parametrize("tile_size", [128])
@pytest.mark.parametrize("ep_size", [1, 8, 32])
@pytest.mark.parametrize("top_k", [1, 2, 8])