From 158ebb089c6efba3a7e1eb650f598b4aa9553eea Mon Sep 17 00:00:00 2001 From: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com> Date: Wed, 30 Jul 2025 13:59:58 +0800 Subject: [PATCH] Remove swizzle for wide ep (#6328) Signed-off-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> Co-authored-by: Jiang Shao <91270701+StudyingShao@users.noreply.github.com> --- .../mixtureOfExpertsBackendBenchmarkFixture.h | 2 +- .../cutlass_kernels/include/moe_kernels.h | 32 +++++----- .../cutlass_kernels/moe_gemm/moe_kernels.cu | 60 +++++++++++++------ .../mixtureOfExpertsPlugin.cpp | 4 +- cpp/tensorrt_llm/thop/moeOp.cpp | 22 +++---- .../kernels/mixtureOfExpertsTest.cu | 17 +++--- .../_torch/custom_ops/torch_custom_ops.py | 2 + .../modules/fused_moe/fused_moe_cutlass.py | 1 + .../modules/fused_moe/fused_moe_wide_ep.py | 3 +- 9 files changed, 86 insertions(+), 57 deletions(-) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h index 816af1f094..35f3bbd99b 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h @@ -780,7 +780,7 @@ public: auto stream = streamPtr->get(); MoeMinLatencyParams min_latency_params; #ifdef USING_OSS_CUTLASS_MOE_GEMM - mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr, + mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExperts, mUseFinalScale ? mScaleProbs : nullptr, mExpertWeight1, mExpertBias1, mActType, mExpertWeight2, mExpertBias2, mQuantParams, mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, /*enable_alltoall=*/false, mUseLora, mLoraParams, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 6adf5cbf34..043208d9b2 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -366,14 +366,14 @@ public: = 0; virtual std::vector getTactics() = 0; - virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) + virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationType fc1_activation_type, void const* fc2_expert_weights, + void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output, int* unpermuted_row_to_permuted_row, 
MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) = 0; // Aliases for profiling the gemms @@ -513,14 +513,14 @@ public: return RunnerType::getConfigs(sm); } - void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; + void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationType fc1_activation_type, void const* fc2_expert_weights, + void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work static void gemm1(MoeGemmRunner& gemm_runner, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 963ba2a291..04149e28f0 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -1017,7 +1017,7 @@ __device__ uint32_t quantizePackedFP4Value(ComputeElem& post_act_val, float glob __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id, int64_t elem_idx, int64_t num_cols, int64_t max_tokens_per_expert, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf = true) { static constexpr int CVT_FP4_NUM_THREADS_PER_SF = NVFP4_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; @@ -1032,11 +1032,31 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int elem_idx, std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); if (sf_out) { - auto const sf_in = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, source_token_id, elem_idx, - std::nullopt /* numRows */, num_cols, const_cast(input_sf), - FP4QuantizationSFLayout::SWIZZLED); - *sf_out = *sf_in; + if (input_sf) + { + if (swizzled_input_sf) + { + auto const sf_in = 
cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, source_token_id, elem_idx, + std::nullopt /* numRows */, num_cols, + const_cast(input_sf), + FP4QuantizationSFLayout::SWIZZLED); + *sf_out = *sf_in; + } + else + { + auto const sf_in = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, source_token_id, elem_idx, + std::nullopt /* numRows */, num_cols, + const_cast(input_sf), + FP4QuantizationSFLayout::LINEAR); + *sf_out = *sf_in; + } + } + else + { + *sf_out = 0x00; + } } } @@ -1428,7 +1448,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + int64_t const num_experts_per_node) { #ifdef ENABLE_FP4 constexpr bool is_fp4 = std::is_same_v; @@ -1498,7 +1519,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, cols, num_rows, - fc1_act_sf_flat, input_sf); + fc1_act_sf_flat, input_sf, swizzled_input_sf); dest_row_ptr[elem_index] = in_vec; } } @@ -1528,7 +1549,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int const k, int const num_experts_per_node, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, cudaStream_t stream) { #ifdef ENABLE_FP4 // TODO Currently this is a bit hacky because we assume we are in FP8_MXFP4 mode if activations are FP8. 
@@ -1566,7 +1587,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, config.attrs = attrs; cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales, permuted_row_to_unpermuted_row, num_rows, cols, k, fc1_act_global_scale, use_per_expert_act_scale, - expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node); + expert_first_token_offset, fc1_act_sf_flat, input_sf, swizzled_input_sf, num_experts_per_node); } enum class ScaleMode : int @@ -3103,14 +3124,14 @@ void CutlassMoeFCRunner void CutlassMoeFCRunner::runMoe( - void const* input_activations_void, void const* input_sf_void, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights_void, void const* fc1_expert_biases_void, - ActivationType fc1_activation_type, void const* fc2_expert_weights_void, void const* fc2_expert_biases_void, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const full_num_experts, int const experts_per_token, char* workspace_ptr, void* final_output_void, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) + void const* input_activations_void, void const* input_sf_void, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights_void, + void const* fc1_expert_biases_void, ActivationType fc1_activation_type, void const* fc2_expert_weights_void, + void const* fc2_expert_biases_void, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const full_num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output_void, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) { static constexpr bool int_scales_required = std::is_same::value || std::is_same::value; @@ -3314,7 +3335,8 @@ void CutlassMoeFCRunner(permuted_data_), token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale, - use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream); + use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, swizzled_input_sf, + stream); sync_check_cuda_error(stream); diff --git a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp index 6db0e4a382..a81ddde283 100644 --- a/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp +++ b/cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -957,7 +957,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, MoeMinLatencyParams min_latency_params{}; mMOERunner->setTactic(gemm1, gemm2); #ifdef USING_OSS_CUTLASS_MOE_GEMM - mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, + mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, true, 
static_cast(inputs[getTokenSelectedExpertsIndex()]), hasFinalScales() ? static_cast(inputs[getTokenFinalScalesIndex()]) : nullptr, inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType, @@ -968,7 +968,7 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, /*enable_alltoall=*/false, hasLora(), lora_params, /*use_deepseek_fp8_block_scale=*/false, /*min_latency_mode=*/false, min_latency_params, stream); #else - mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, + mMOERunner->runMoe(inputs[getInputTensorIndex()], nullptr, true, static_cast(inputs[getTokenSelectedExpertsIndex()]), hasFinalScales() ? static_cast(inputs[getTokenFinalScalesIndex()]) : nullptr, inputs[getExpertWeights1Index()], hasBias() ? inputs[getExpertBias1Index()] : nullptr, mActivationType, diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index b3f9ef876e..ad1a6b5b92 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -215,9 +215,10 @@ public: torch::optional const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights, torch::optional const& fc2_expert_biases, torch::optional> const& quant_scales, - torch::optional const& input_sf, int64_t const tp_size, int64_t const tp_rank, - int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, - bool const enable_alltoall, bool min_latency_mode, torch::optional> const& profile_ids) + torch::optional const& input_sf, bool const swizzled_input_sf, int64_t const tp_size, + int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, + int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, + torch::optional> const& profile_ids) { std::lock_guard lock(mMutex); // Free the profile workspace to save memory @@ -297,7 +298,7 @@ public: ::tensorrt_llm::kernels::LoraParams lora_params{}; #ifdef USING_OSS_CUTLASS_MOE_GEMM mKernelRunner->runMoe(input.const_data_ptr(), - input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, + input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf, reinterpret_cast(token_selected_experts.const_data_ptr()), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value().const_data_ptr()) : nullptr, @@ -311,7 +312,7 @@ public: mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream); #else mKernelRunner->runMoe(input.const_data_ptr(), - input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, + input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf, reinterpret_cast(token_selected_experts.const_data_ptr()), token_final_scales.has_value() ? 
reinterpret_cast(token_final_scales.value().const_data_ptr()) : nullptr, @@ -333,9 +334,10 @@ public: torch::Tensor const& fc1_expert_weights, torch::optional const& fc1_expert_biases, torch::Tensor const& fc2_expert_weights, torch::optional const& fc2_expert_biases, torch::optional> const& quant_scales, - torch::optional const& input_sf, int64_t const tp_size, int64_t const tp_rank, - int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, - bool const enable_alltoall, bool min_latency_mode, torch::optional> const& profile_ids) + torch::optional const& input_sf, bool const swizzled_input_sf, int64_t const tp_size, + int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, + int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, + torch::optional> const& profile_ids) { std::lock_guard lock(mMutex); @@ -424,7 +426,7 @@ public: ::tensorrt_llm::kernels::LoraParams lora_params{}; #ifdef USING_OSS_CUTLASS_MOE_GEMM mKernelRunner->runMoe(input.const_data_ptr(), - input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, + input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf, reinterpret_cast(token_selected_experts.const_data_ptr()), token_final_scales.has_value() ? reinterpret_cast(token_final_scales.value().const_data_ptr()) : nullptr, @@ -438,7 +440,7 @@ public: mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, stream); #else mKernelRunner->runMoe(input.const_data_ptr(), - input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, + input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, swizzled_input_sf, reinterpret_cast(token_selected_experts.const_data_ptr()), token_final_scales.has_value() ? 
reinterpret_cast(token_final_scales.value().const_data_ptr()) : nullptr, diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index a44ca2a4a8..a904c0e6d5 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -1106,15 +1106,16 @@ protected: MoeMinLatencyParams min_latency_params; mMoERunner.setTactic(tactic1, tactic2); #ifdef USING_OSS_CUTLASS_MOE_GEMM - mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, mActType, - weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, - mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, enable_alltoall, - mUseLora, lora_params, useFp8BlockScales, minLatencyMode, min_latency_params, stream); + mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, + mActType, weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize, + mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, + parallelism_config, enable_alltoall, mUseLora, lora_params, useFp8BlockScales, minLatencyMode, + min_latency_params, stream); #else - mMoERunner.runMoe(mInputTensor, nullptr, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, mActType, - weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize, mInterSize / parallelism_config.tp_size, - mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, parallelism_config, mUseLora, lora_params, - useFp8BlockScales, minLatencyMode, min_latency_params, stream); + mMoERunner.runMoe(mInputTensor, nullptr, true, mSelectedExpert, mTokenFinalScales, weight1_ptr, bias1_ptr, + mActType, weight2_ptr, bias2_ptr, quant_params, mTotalTokens, mHiddenSize, + mInterSize / parallelism_config.tp_size, mNumExperts, mK, mWorkspace, mFinalOutput, mSourceToExpandedMap, + parallelism_config, mUseLora, lora_params, useFp8BlockScales, minLatencyMode, min_latency_params, stream); #endif check_cuda_error(cudaStreamSynchronize(stream)); diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 6d1e8c05e5..a4099a9712 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -124,6 +124,7 @@ def fused_moe( output_dtype: torch.dtype, quant_scales: List[torch.Tensor], input_sf: Optional[torch.Tensor] = None, + swizzled_input_sf: bool = True, tp_size: int = 1, tp_rank: int = 0, ep_size: int = 1, @@ -191,6 +192,7 @@ def fused_moe( fc2_expert_biases, quant_scales, input_sf, + swizzled_input_sf, tp_size, tp_rank, ep_size, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 82e01a83cd..a24708ed47 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -275,6 +275,7 @@ class CutlassFusedMoE(MoE): output_dtype, quant_scales=self.quant_scales, input_sf=x_sf, + swizzled_input_sf=True, tp_size=self.tp_size, tp_rank=self.tp_rank, ep_size=self.ep_size, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index 829a449748..07873fbfad 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ 
b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -603,6 +603,7 @@ class WideEPMoE(MoE): output_dtype, quant_scales=quant_scales, input_sf=x_sf, + swizzled_input_sf=sf_swizzle, tp_size=self.tp_size, tp_rank=self.tp_rank, ep_size=ep_size, @@ -881,7 +882,7 @@ class WideEPMoE(MoE): self.alltoall_workspace, self.ep_rank, self.ep_size) - if self.has_nvfp4: + if self.has_nvfp4 and is_sf_swizzle: x_sf = swizzle_sf(x_sf, x.shape[0], x.shape[1] * 2, self.scaling_vector_size)
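Editor's note (not part of the patch): the core of this change is a new swizzled_input_sf flag threaded from the Python fused-MoE entry points down to runMoe() and writeSF(). When it is true (the default, and what the plugin, benchmark, unit tests, and the Cutlass module pass), the kernels keep reading the incoming NVFP4 scale factors in the SWIZZLED layout as before; when it is false, writeSF() reads them in the LINEAR layout, and a null input_sf now zero-fills the destination (*sf_out = 0x00) instead of being dereferenced. This lets the wide-EP path pass swizzled_input_sf=False and skip its per-token swizzle_sf() call when the scale factors are still linear. The snippet below is a self-contained toy sketch of that dispatch, not TensorRT-LLM code; TOY_SWIZZLE is a made-up permutation standing in for the real FP4QuantizationSFLayout::SWIZZLED pattern, and gather_input_sf() is a hypothetical helper, not an API from the repository.

    from typing import Optional

    import numpy as np

    # Hypothetical stand-in for the real swizzled scale-factor layout (illustration only).
    TOY_SWIZZLE = np.array([0, 4, 1, 5, 2, 6, 3, 7])


    def gather_input_sf(input_sf: Optional[np.ndarray], swizzled_input_sf: bool) -> np.ndarray:
        """Toy analogue of the new writeSF() input path: return the incoming scale
        factors in a canonical (linear) order regardless of how the caller laid them out."""
        if input_sf is None:
            # New behaviour in this patch: no input scale factors -> zero-fill
            # rather than dereferencing a null pointer.
            return np.zeros(len(TOY_SWIZZLE), dtype=np.uint8)
        if swizzled_input_sf:
            # Caller provided swizzled scale factors (previous, default behaviour):
            # undo the toy permutation while copying.
            return input_sf[np.argsort(TOY_SWIZZLE)]
        # Caller provided linear scale factors (new wide-EP path): copy as-is.
        return input_sf.copy()


    linear_sf = np.arange(8, dtype=np.uint8)
    swizzled_sf = linear_sf[TOY_SWIZZLE]

    assert np.array_equal(gather_input_sf(swizzled_sf, swizzled_input_sf=True), linear_sf)
    assert np.array_equal(gather_input_sf(linear_sf, swizzled_input_sf=False), linear_sf)
    assert not gather_input_sf(None, swizzled_input_sf=False).any()

In the real kernels the copy happens one scale-factor element at a time inside writeSF(), with cvt_quant_to_fp4_get_sf_out_offset() computing the source offset for either the SWIZZLED or the LINEAR layout; the sketch only illustrates which layout the flag selects.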