diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index c8f30c8960..3602147b68 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -1144,16 +1144,20 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
         weight_scale_name = "weight_scale"
 
         assert (len(module.interleave) == 2)
+
+        # Ensure that the input_scale remains aligned across all ranks for W4A8 custom.
+        input_scale_expert_ids = module.initial_local_expert_ids if not w4a8_custom else range(
+            module.num_experts)
         # fc31 scales
         all_w3_input_scales = [
             load_weight_shard(weights[f"{expert_id}.w3.input_scale"],
                               device=self.device)
-            for expert_id in module.initial_local_expert_ids
+            for expert_id in input_scale_expert_ids
         ]
         all_w1_input_scales = [
             load_weight_shard(weights[f"{expert_id}.w1.input_scale"],
                               device=self.device)
-            for expert_id in module.initial_local_expert_ids
+            for expert_id in input_scale_expert_ids
         ]
         all_w3_w1_input_scales_max = torch.max(
             torch.stack(all_w3_input_scales),
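
Context for the change above: under expert parallelism each rank previously reduced input_scale over only its initial_local_expert_ids, so the resulting max could differ from rank to rank; iterating over range(module.num_experts) on the W4A8-custom path makes every rank compute the same value. A minimal standalone sketch of that effect (num_experts, experts_per_rank, and per_expert_scales are illustrative assumptions, not values from the patch):

    import torch

    # Illustrative only: 8 experts split across 2 ranks, 4 experts each.
    num_experts = 8
    experts_per_rank = 4
    per_expert_scales = torch.tensor([0.5, 0.9, 1.3, 0.7, 2.1, 0.4, 1.0, 0.8])

    for rank in range(num_experts // experts_per_rank):
        local_ids = list(range(rank * experts_per_rank, (rank + 1) * experts_per_rank))
        # Reducing over only the local shard (pre-patch behavior) is rank-dependent.
        local_max = per_expert_scales[local_ids].max().item()
        # Reducing over all experts (patched W4A8-custom behavior) is identical on every rank.
        global_max = per_expert_scales.max().item()
        print(f"rank {rank}: local max {local_max:.1f} vs global max {global_max:.1f}")

    # rank 0: local max 1.3 vs global max 2.1
    # rank 1: local max 2.1 vs global max 2.1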