mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[fix] Fix W4A8 weight loading error in WInt4AFP8FusedMoEMethod (#5026)
Signed-off-by: Xiaowei Wang <100599594+xiaoweiw-nv@users.noreply.github.com>
This commit is contained in:
parent
12ffdcbf53
commit
ec6b1821c7
@@ -550,13 +550,13 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
|
||||
module.intermediate_size_per_partition // 2)
|
||||
|
||||
fc31_act_scale = nn.Parameter(torch.empty(1,
|
||||
self.hidden_size,
|
||||
dtype=self.dtype),
|
||||
module.hidden_size,
|
||||
dtype=module.dtype),
|
||||
requires_grad=False)
|
||||
module.register_parameter("fc31_act_scale", fc31_act_scale)
|
||||
|
||||
fc2_act_scale = nn.Parameter(torch.empty(
|
||||
1, self.intermediate_size_per_partition, 1, dtype=self.dtype),
|
||||
1, module.intermediate_size_per_partition, 1, dtype=module.dtype),
|
||||
requires_grad=False)
|
||||
module.register_parameter("fc2_act_scale", fc2_act_scale)
|
||||
|
||||
@@ -701,11 +701,11 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
|
||||
all_w3_w1_input_scales_max = torch.max(
|
||||
torch.stack(all_w3_input_scales),
|
||||
torch.stack(all_w1_input_scales)).max()
|
||||
self.fc31_act_scale.data.copy_(
|
||||
torch.ones_like(self.fc31_act_scale) *
|
||||
module.fc31_act_scale.data.copy_(
|
||||
torch.ones_like(module.fc31_act_scale) *
|
||||
(1 / all_w3_w1_input_scales_max))
|
||||
self.fc31_alpha.data.copy_((torch.ones_like(self.fc31_alpha) *
|
||||
all_w3_w1_input_scales_max).float())
|
||||
module.fc31_alpha.data.copy_((torch.ones_like(module.fc31_alpha) *
|
||||
all_w3_w1_input_scales_max).float())
|
||||
|
||||
all_w3_scales = [
|
||||
load_weight_shard(weights[f"{expert_id}.w3.weight_scale_inv"],
|
||||
@@ -744,11 +744,12 @@ class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
|
||||
for expert_id in module.initial_local_expert_ids
|
||||
]
|
||||
all_w2_input_scales_max = torch.stack(all_w2_input_scales).to(
|
||||
self.dtype).max()
|
||||
self.fc2_act_scale.data.copy_(
|
||||
torch.ones_like(self.fc2_act_scale) * (1 / all_w2_input_scales_max))
|
||||
self.fc2_alpha.data.copy_(
|
||||
(torch.ones_like(self.fc2_alpha) * all_w2_input_scales_max).float())
|
||||
module.dtype).max()
|
||||
module.fc2_act_scale.data.copy_(
|
||||
torch.ones_like(module.fc2_act_scale) *
|
||||
(1 / all_w2_input_scales_max))
|
||||
module.fc2_alpha.data.copy_((torch.ones_like(module.fc2_alpha) *
|
||||
all_w2_input_scales_max).float())
|
||||
|
||||
all_w2_scales = [
|
||||
load_weight_shard(weights[f"{expert_id}.w2.weight_scale_inv"],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user