Remove swizzle for wide ep (#6328)

Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
Co-authored-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com>
This commit is contained in:
Zongfei Jing 2025-07-30 13:59:58 +08:00 committed by GitHub
parent ab6fb9f05d
commit 158ebb089c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 86 additions and 57 deletions

View File

@ -780,7 +780,7 @@ public:
auto stream = streamPtr->get();
MoeMinLatencyParams min_latency_params;
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr,
mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens,
mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
parallelism_config, /*enable_alltoall=*/false, mUseLora, mLoraParams,

View File

@ -366,14 +366,14 @@ public:
= 0;
virtual std::vector<cutlass_extensions::CutlassGemmConfig> getTactics() = 0;
virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases,
ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
void const* fc1_expert_biases, ActivationType fc1_activation_type, void const* fc2_expert_weights,
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
= 0;
// Aliases for profiling the gemms
@ -513,14 +513,14 @@ public:
return RunnerType::getConfigs(sm);
}
void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases,
ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf,
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights,
void const* fc1_expert_biases, ActivationType fc1_activation_type, void const* fc2_expert_weights,
void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr,
void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override;
// We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work
static void gemm1(MoeGemmRunner<T, WeightType, OutputType, ScaleBiasType>& gemm_runner,

View File

@ -1017,7 +1017,7 @@ __device__ uint32_t quantizePackedFP4Value(ComputeElem& post_act_val, float glob
__device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id,
int64_t elem_idx, int64_t num_cols, int64_t max_tokens_per_expert,
TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf)
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf = true)
{
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = NVFP4_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
@ -1032,11 +1032,31 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int
elem_idx, std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED);
if (sf_out)
{
auto const sf_in = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF, NVFP4_VEC_SIZE>(std::nullopt /* batchIdx */, source_token_id, elem_idx,
std::nullopt /* numRows */, num_cols, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
FP4QuantizationSFLayout::SWIZZLED);
*sf_out = *sf_in;
if (input_sf)
{
if (swizzled_input_sf)
{
auto const sf_in = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF, NVFP4_VEC_SIZE>(std::nullopt /* batchIdx */, source_token_id, elem_idx,
std::nullopt /* numRows */, num_cols,
const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
FP4QuantizationSFLayout::SWIZZLED);
*sf_out = *sf_in;
}
else
{
auto const sf_in = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF, NVFP4_VEC_SIZE>(std::nullopt /* batchIdx */, source_token_id, elem_idx,
std::nullopt /* numRows */, num_cols,
const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
FP4QuantizationSFLayout::LINEAR);
*sf_out = *sf_in;
}
}
else
{
*sf_out = 0x00;
}
}
}
@ -1428,7 +1448,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int64_t const k,
float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset,
TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node)
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf,
int64_t const num_experts_per_node)
{
#ifdef ENABLE_FP4
constexpr bool is_fp4 = std::is_same_v<ExpandedActivationsType, __nv_fp4_e2m1>;
@ -1498,7 +1519,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
{
assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations");
writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, cols, num_rows,
fc1_act_sf_flat, input_sf);
fc1_act_sf_flat, input_sf, swizzled_input_sf);
dest_row_ptr[elem_index] = in_vec;
}
}
@ -1528,7 +1549,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int const k,
int const num_experts_per_node, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream)
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, cudaStream_t stream)
{
#ifdef ENABLE_FP4
// TODO Currently this is a bit hacky because we assume we are in FP8_MXFP4 mode if activations are FP8.
@ -1566,7 +1587,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
config.attrs = attrs;
cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales,
permuted_row_to_unpermuted_row, num_rows, cols, k, fc1_act_global_scale, use_per_expert_act_scale,
expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node);
expert_first_token_offset, fc1_act_sf_flat, input_sf, swizzled_input_sf, num_experts_per_node);
}
enum class ScaleMode : int
@ -3103,14 +3124,14 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
template <class T, class WeightType, class OutputType, class InputType, class BackBoneType, class Enable>
void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enable>::runMoe(
void const* input_activations_void, void const* input_sf_void, int const* token_selected_experts,
float const* token_final_scales, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void,
ActivationType fc1_activation_type, void const* fc2_expert_weights_void, void const* fc2_expert_biases_void,
QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size,
int const full_num_experts, int const experts_per_token, char* workspace_ptr, void* final_output_void,
int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall,
bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode,
MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf,
int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights_void,
void const* fc1_expert_biases_void, ActivationType fc1_activation_type, void const* fc2_expert_weights_void,
void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size,
int64_t const inter_size, int const full_num_experts, int const experts_per_token, char* workspace_ptr,
void* final_output_void, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config,
bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale,
bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream)
{
static constexpr bool int_scales_required
= std::is_same<WeightType, uint8_t>::value || std::is_same<WeightType, cutlass::uint4b_t>::value;
@ -3314,7 +3335,8 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
expandInputRowsKernelLauncher(input_activations, reinterpret_cast<ExpandedActivationsType*>(permuted_data_),
token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows,
hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale,
use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream);
use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, swizzled_input_sf,
stream);
sync_check_cuda_error(stream);

View File

@ -957,7 +957,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
MoeMinLatencyParams min_latency_params{};
mMOERunner->setTactic(gemm1, gemm2);
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr,
mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, true,
static_cast<int const*>(inputs[getTokenSelectedExpertsIndex()]),
hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType,
@ -968,7 +968,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
/*enable_alltoall=*/false, hasLora(), lora_params, /*use_deepseek_fp8_block_scale=*/false,
/*min_latency_mode=*/false, min_latency_params, stream);
#else
mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr,
mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, true,
static_cast<int const*>(inputs[getTokenSelectedExpertsIndex()]),
hasFinalScales() ? static_cast<float const*>(inputs[getTokenFinalScalesIndex()]) : nullptr,
inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType,

View File

@ -215,9 +215,10 @@ public:
torch::optional<torch::Tensor> const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights,
torch::optional<torch::Tensor> const& fc2_expert_biases,
torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
torch::optional<torch::Tensor> const& input_sf, bool const swizzled_input_sf, int64_t const tp_size,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode,
torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
{
std::lock_guard<std::mutex> lock(mMutex);
// Free the profile workspace to save memory
@ -297,7 +298,7 @@ public:
::tensorrt_llm::kernels::LoraParams lora_params{};
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mKernelRunner->runMoe(input.const_data_ptr(),
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
@ -311,7 +312,7 @@ public:
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
#else
mKernelRunner->runMoe(input.const_data_ptr(),
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
@ -333,9 +334,10 @@ public:
torch::Tensor const& fc1_expert_weights, torch::optional<torch::Tensor> const& fc1_expert_biases,
torch::Tensor const& fc2_expert_weights, torch::optional<torch::Tensor> const& fc2_expert_biases,
torch::optional<c10::ArrayRef<torch::Tensor>> const& quant_scales,
torch::optional<torch::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
bool const enable_alltoall, bool min_latency_mode, torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
torch::optional<torch::Tensor> const& input_sf, bool const swizzled_input_sf, int64_t const tp_size,
int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode,
torch::optional<c10::ArrayRef<int64_t>> const& profile_ids)
{
std::lock_guard<std::mutex> lock(mMutex);
@ -424,7 +426,7 @@ public:
::tensorrt_llm::kernels::LoraParams lora_params{};
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mKernelRunner->runMoe(input.const_data_ptr(),
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,
@ -438,7 +440,7 @@ public:
mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream);
#else
mKernelRunner->runMoe(input.const_data_ptr(),
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf,
reinterpret_cast<int const*>(token_selected_experts.const_data_ptr()),
token_final_scales.has_value() ? reinterpret_cast<float const*>(token_final_scales.value().const_data_ptr())
: nullptr,

View File

@ -1106,15 +1106,16 @@ protected:
MoeMinLatencyParams min_latency_params;
mMoERunner.setTactic(tactic1, tactic2);
#ifdef USING_OSS_CUTLASS_MOE_GEMM
mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, mActType,
weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size,
mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, enable_alltoall,
mUseLora, lora_params, useFp8BlockScales, minLatencyMode, min_latency_params, stream);
mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr,
mActType, weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize,
mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
parallelism_config, enable_alltoall, mUseLora, lora_params, useFp8BlockScales, minLatencyMode,
min_latency_params, stream);
#else
mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, mActType,
weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size,
mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, mUseLora, lora_params,
useFp8BlockScales, minLatencyMode, min_latency_params, stream);
mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr,
mActType, weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize,
mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap,
parallelism_config, mUseLora, lora_params, useFp8BlockScales, minLatencyMode, min_latency_params, stream);
#endif
check_cuda_error(cudaStreamSynchronize(stream));

View File

@ -124,6 +124,7 @@ def fused_moe(
output_dtype: torch.dtype,
quant_scales: List[torch.Tensor],
input_sf: Optional[torch.Tensor] = None,
swizzled_input_sf: bool = True,
tp_size: int = 1,
tp_rank: int = 0,
ep_size: int = 1,
@ -191,6 +192,7 @@ def fused_moe(
fc2_expert_biases,
quant_scales,
input_sf,
swizzled_input_sf,
tp_size,
tp_rank,
ep_size,

View File

@ -275,6 +275,7 @@ class CutlassFusedMoE(MoE):
output_dtype,
quant_scales=self.quant_scales,
input_sf=x_sf,
swizzled_input_sf=True,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=self.ep_size,

View File

@ -603,6 +603,7 @@ class WideEPMoE(MoE):
output_dtype,
quant_scales=quant_scales,
input_sf=x_sf,
swizzled_input_sf=sf_swizzle,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=ep_size,
@ -881,7 +882,7 @@ class WideEPMoE(MoE):
self.alltoall_workspace,
self.ep_rank, self.ep_size)
if self.has_nvfp4:
if self.has_nvfp4 and is_sf_swizzle:
x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2,
self.scaling_vector_size)