Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[https://nvbugs/5449155][fix] Fix DeepSeek R1 weight loading for TP16 (#6913)
Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>

This commit is contained in:
parent 21291f3d8e
commit 93e623b455
@@ -650,9 +650,12 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
         load_weights_vanilla_helper(module, weights)
 
         scale_name = self._get_scale_name(weights)
-        weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank,
-                                         module.tp_mode).squeeze()
+        full_weight_scale = weights[0][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_weight_scale.dim() == 4:
+            full_weight_scale = full_weight_scale.squeeze(1).squeeze(-1)
+        weight_scale = load_weight_shard(full_weight_scale, module.tp_size,
+                                         module.tp_rank, module.tp_mode)
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
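For context, the shape handling introduced above can be sketched in isolation. The block counts below are illustrative, not taken from a real checkpoint; only the squeeze pattern mirrors the diff:

import torch

# A plain FP8 block-scale tensor is 2-D, e.g. (out_blocks, in_blocks).
plain_scale = torch.rand(56, 56)

# A modelopt fp8_pb_wo checkpoint can carry the same data with two extra
# singleton dimensions, e.g. (out_blocks, 1, in_blocks, 1).
modelopt_scale = plain_scale.unsqueeze(1).unsqueeze(-1)
assert modelopt_scale.dim() == 4

# The fix drops those singleton dims before sharding, so both layouts
# reach load_weight_shard with the same 2-D shape.
if modelopt_scale.dim() == 4:
    modelopt_scale = modelopt_scale.squeeze(1).squeeze(-1)
assert torch.equal(modelopt_scale, plain_scale)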
@@ -665,13 +668,23 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
         fused_weight = torch.cat((q_weight, k_weight, v_weight))
 
         scale_name = self._get_scale_name(weights)
-        q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_q_scale = weights[0][scale_name]
+        full_k_scale = weights[1][scale_name]
+        full_v_scale = weights[2][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_q_scale.dim() == 4:
+            full_q_scale = full_q_scale.squeeze(1).squeeze(-1)
+        if full_k_scale.dim() == 4:
+            full_k_scale = full_k_scale.squeeze(1).squeeze(-1)
+        if full_v_scale.dim() == 4:
+            full_v_scale = full_v_scale.squeeze(1).squeeze(-1)
+        q_scale = load_weight_shard(full_q_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        k_scale = load_weight_shard(full_k_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
+        v_scale = load_weight_shard(full_v_scale, module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
 
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_fp8_block_scale)
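A toy walk-through of the fused-QKV scale path above. shard_dim0 is a hypothetical stand-in for load_weight_shard (which also handles tp_mode and other sharding layouts the sketch ignores), and the head/block sizes are invented:

import torch

def shard_dim0(t: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Column-parallel style: each rank keeps one chunk along dim 0.
    return torch.chunk(t, tp_size, dim=0)[tp_rank]

tp_size, tp_rank = 2, 0
q_scale = shard_dim0(torch.rand(32, 56), tp_size, tp_rank)
k_scale = shard_dim0(torch.rand(8, 56), tp_size, tp_rank)
v_scale = shard_dim0(torch.rand(8, 56), tp_size, tp_rank)

# Each shard stays 2-D, so the fused block scale is simply their
# concatenation along dim 0; no trailing .squeeze() is needed.
fused = torch.cat((q_scale, k_scale, v_scale))
assert fused.shape == (24, 56)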
@@ -683,11 +696,18 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
         fused_weight = torch.cat((gate_weight, up_weight))
 
         scale_name = self._get_scale_name(weights)
-        left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
+        full_left_scale = weights[0][scale_name]
+        full_right_scale = weights[1][scale_name]
+        # modelopt fp8_pb_wo can have 2 extra singleton dimensions
+        if full_left_scale.dim() == 4:
+            full_left_scale = full_left_scale.squeeze(1).squeeze(-1)
+        if full_right_scale.dim() == 4:
+            full_right_scale = full_right_scale.squeeze(1).squeeze(-1)
+        left_scale = load_weight_shard(full_left_scale, module.tp_size,
                                        module.tp_rank, module.tp_mode)
-        right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
+        right_scale = load_weight_shard(full_right_scale, module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
+        fused_scale = torch.cat([left_scale, right_scale], dim=0)
         copy_weight(module.weight, fused_weight)
         copy_weight(module.weight_scale, fused_scale)
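One hedged illustration of why the diff squeezes only the known modelopt dimensions instead of calling an unconditional .squeeze(): at a high TP degree, a per-rank scale shard can legitimately be size 1 along a real dimension. This is an assumption about the failure mode, not a confirmed account of the TP16 bug, and the shapes are invented:

import torch

tp_size = 16
full_scale = torch.rand(16, 56)                   # one block row per rank
shard = torch.chunk(full_scale, tp_size, dim=0)[0]
assert shard.shape == (1, 56)

# An unconditional .squeeze() would also drop this legitimate size-1
# dimension and change the tensor's rank ...
assert shard.squeeze().shape == (56,)

# ... whereas squeezing only the extra modelopt dims (dim 1 and dim -1,
# as in the diff) leaves real size-1 block dimensions alone.
four_d = shard.unsqueeze(1).unsqueeze(-1)         # (1, 1, 56, 1)
assert four_d.squeeze(1).squeeze(-1).shape == (1, 56)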