[https://nvbugs/5449155][fix] Fix DeepSeek R1 weight loading for TP16 (#6913)

Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Author: Aurelien Chartier (2025-08-18 19:25:43 -07:00), committed by Jonas Yang CN
parent 21291f3d8e
commit 93e623b455


@@ -650,9 +650,12 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
         load_weights_vanilla_helper(module, weights)

         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
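For reference, a minimal standalone sketch of what the new vanilla path does, assuming a modelopt fp8_pb_wo scale arrives as a 4-D tensor of shape (out_blocks, 1, in_blocks, 1). The shapes and the toy shard_rows() helper are illustrative stand-ins, not the library's load_weight_shard:

import torch

def shard_rows(t: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Toy stand-in for load_weight_shard: slice the first dimension per TP rank.
    rows = t.shape[0] // tp_size
    return t[tp_rank * rows:(tp_rank + 1) * rows]

full_weight_scale = torch.rand(64, 1, 56, 1)  # hypothetical 4-D modelopt layout
if full_weight_scale.dim() == 4:
    # Strip only the two known singleton axes, as the patch does.
    full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)

weight_scale = shard_rows(full_weight_scale, tp_size=16, tp_rank=0)
print(weight_scale.shape)  # torch.Size([4, 56])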
@@ -665,13 +668,23 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
         fused_weight = torch.cat((q_weight, k_weight, v_weight))

         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))

         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
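The fused-QKV path follows the same pattern. A small sketch of the concatenation step, assuming the sharded q/k/v scales are plain 2-D (out_blocks, in_blocks) tensors; the block counts below are made up for illustration:

import torch

q_scale = torch.rand(12, 56)  # hypothetical per-rank Q block scales
k_scale = torch.rand(2, 56)   # hypothetical per-rank K block scales
v_scale = torch.rand(2, 56)   # hypothetical per-rank V block scales

# Stack along the output-block axis, mirroring how the q/k/v weights
# themselves are fused into a single matrix.
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
print(fused_fp8_block_scale.shape)  # torch.Size([16, 56])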
@@ -683,11 +696,18 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
         fused_weight = torch.cat((gate_weight, up_weight))

         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)

         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)
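Dropping the trailing .squeeze() matters at high TP degrees: a sharded scale can legitimately have a size-1 axis, which an unconditional squeeze() would remove and break the shape match in copy_weight. The sketch below uses made-up shapes to illustrate the hazard; the single-row shard is an assumption about what TP16 can produce, inferred from the commit title:

import torch

shard = torch.rand(1, 56)  # hypothetical per-rank scale shard at TP16

# An unconditional squeeze() flattens the shard to 1-D, which no longer
# matches a 2-D weight_scale parameter:
print(shard.squeeze().shape)  # torch.Size([56])

# Stripping only the two known singleton axes of a 4-D modelopt tensor
# leaves a legitimate size-1 shard axis intact:
full = torch.rand(1, 1, 56, 1)  # hypothetical 4-D modelopt layout
print(full.squeeze(1).squeeze(-1).shape)  # torch.Size([1, 56])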