Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-16 15:55:08 +08:00)
[None][fix] Ensure that the W4A8 custom input scale remains aligned across all ranks (#7614)
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
This commit is contained in:
Parent: cf55927064
Commit: 103b554734
@@ -1144,16 +1144,20 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
         weight_scale_name = "weight_scale"
 
         assert (len(module.interleave) == 2)
 
+        # Ensure that the input_scale remains aligned across all ranks for W4A8 custom.
+        input_scale_expert_ids = module.initial_local_expert_ids if not w4a8_custom else range(
+            module.num_experts)
         # fc31 scales
         all_w3_input_scales = [
             load_weight_shard(weights[f"{expert_id}.w3.input_scale"],
                               device=self.device)
-            for expert_id in module.initial_local_expert_ids
+            for expert_id in input_scale_expert_ids
         ]
         all_w1_input_scales = [
             load_weight_shard(weights[f"{expert_id}.w1.input_scale"],
                               device=self.device)
-            for expert_id in module.initial_local_expert_ids
+            for expert_id in input_scale_expert_ids
         ]
         all_w3_w1_input_scales_max = torch.max(
             torch.stack(all_w3_input_scales),
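The fix swaps the expert range used by the input-scale comprehensions: on the W4A8 custom path, every rank now iterates range(module.num_experts) instead of only module.initial_local_expert_ids, so the torch.max reduction that follows sees the identical set of scales on every rank. A minimal, self-contained sketch of that reasoning follows; the scale values, the rank split, and the fused_input_scale helper are illustrative assumptions, not the TensorRT-LLM implementation:

import torch

num_experts = 4
# Hypothetical per-expert input scales as they would be read from a checkpoint.
weights = {
    f"{expert_id}.w3.input_scale": torch.tensor(0.1 * (expert_id + 1))
    for expert_id in range(num_experts)
}

def fused_input_scale(expert_ids):
    # Mirrors the pattern in the diff: stack the per-expert scales, then
    # reduce with max so the fused layer uses a single input scale.
    scales = [weights[f"{expert_id}.w3.input_scale"] for expert_id in expert_ids]
    return torch.stack(scales).max()

# Expert-parallel split: each rank owns a disjoint slice of experts.
rank0_experts, rank1_experts = range(0, 2), range(2, 4)

# Old behavior: each rank reduces over only its local experts, so the
# fused scale diverges across ranks (0.2 on rank 0 vs. 0.4 on rank 1).
print(fused_input_scale(rank0_experts), fused_input_scale(rank1_experts))

# New behavior on the W4A8 custom path: every rank reduces over all
# experts, so the fused input scale is identical everywhere (0.4).
print(fused_input_scale(range(num_experts)))

Because each rank previously reduced over only its local slice, ranks could end up with different fused input scales for the same layer; reducing over all experts makes the scale rank-invariant by construction.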