[fix] Fix W4A8 weight loading error in WInt4AFP8FusedMoEMethod (#5026)

Signed-off-by: Xiaowei Wang <100599594+xiaoweiw-nv@users.noreply.github.com>
Author: Xiaowei Wang 2025-06-10 15:09:06 +08:00 (committed by GitHub)
Commit: ec6b1821c7 (parent: 12ffdcbf53)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)

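Root cause, as the hunks below show: `WInt4AFP8FusedMoEMethod` registers the activation-scale parameters on the MoE `module` it serves via `module.register_parameter`, so `fc31_act_scale`, `fc2_act_scale`, `fc31_alpha`, and `fc2_alpha` live on the module rather than on the method object, and the same goes for the `hidden_size`, `intermediate_size_per_partition`, and `dtype` attributes. The old code read them through `self`, which fails with an `AttributeError` during weight loading; the fix routes every access through `module`.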

@@ -550,13 +550,13 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
                               module.intermediate_size_per_partition // 2)
         fc31_act_scale = nn.Parameter(torch.empty(1,
-                                                  self.hidden_size,
-                                                  dtype=self.dtype),
+                                                  module.hidden_size,
+                                                  dtype=module.dtype),
                                       requires_grad=False)
         module.register_parameter("fc31_act_scale", fc31_act_scale)

         fc2_act_scale = nn.Parameter(torch.empty(
-            1, self.intermediate_size_per_partition, 1, dtype=self.dtype),
+            1, module.intermediate_size_per_partition, 1, dtype=module.dtype),
                                      requires_grad=False)
         module.register_parameter("fc2_act_scale", fc2_act_scale)
@@ -701,11 +701,11 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
             all_w3_w1_input_scales_max = torch.max(
                 torch.stack(all_w3_input_scales),
                 torch.stack(all_w1_input_scales)).max()
-            self.fc31_act_scale.data.copy_(
-                torch.ones_like(self.fc31_act_scale) *
+            module.fc31_act_scale.data.copy_(
+                torch.ones_like(module.fc31_act_scale) *
                 (1 / all_w3_w1_input_scales_max))
-            self.fc31_alpha.data.copy_((torch.ones_like(self.fc31_alpha) *
-                                        all_w3_w1_input_scales_max).float())
+            module.fc31_alpha.data.copy_((torch.ones_like(module.fc31_alpha) *
+                                          all_w3_w1_input_scales_max).float())

             all_w3_scales = [
                 load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"],
@@ -744,11 +744,12 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
                 for expert_id in module.initial_local_expert_ids
             ]
             all_w2_input_scales_max = torch.stack(all_w2_input_scales).to(
-                self.dtype).max()
-            self.fc2_act_scale.data.copy_(
-                torch.ones_like(self.fc2_act_scale) * (1 / all_w2_input_scales_max))
-            self.fc2_alpha.data.copy_(
-                (torch.ones_like(self.fc2_alpha) * all_w2_input_scales_max).float())
+                module.dtype).max()
+            module.fc2_act_scale.data.copy_(
+                torch.ones_like(module.fc2_act_scale) *
+                (1 / all_w2_input_scales_max))
+            module.fc2_alpha.data.copy_((torch.ones_like(module.fc2_alpha) *
+                                         all_w2_input_scales_max).float())

             all_w2_scales = [
                 load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],