[None][fix] Ensure that the W4A8 custom input scale remains aligned across all ranks (#7614)

Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
This commit is contained in:
Void 2025-09-16 11:04:26 +08:00 committed by GitHub
parent cf55927064
commit 103b554734
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1144,16 +1144,20 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
weight_scale_name = "weight_scale"
assert (len(module.interleave) == 2)
# Ensure that the input_scale remains aligned across all ranks for W4A8 custom.
input_scale_expert_ids = module.initial_local_expert_ids if not w4a8_custom else range(
module.num_experts)
# fc31 scales
all_w3_input_scales = [
load_weight_shard(weights[f"{expert_id}.w3.input_scale"],
device=self.device)
for expert_id in module.initial_local_expert_ids
for expert_id in input_scale_expert_ids
]
all_w1_input_scales = [
load_weight_shard(weights[f"{expert_id}.w1.input_scale"],
device=self.device)
for expert_id in module.initial_local_expert_ids
for expert_id in input_scale_expert_ids
]
all_w3_w1_input_scales_max = torch.max(
torch.stack(all_w3_input_scales),