fix: glm5.1 pp model loading (#42944)

Signed-off-by: UranusSeven <109661872+UranusSeven@users.noreply.github.com>
This commit is contained in:
Uranus
2026-06-01 15:14:47 +08:00
committed by GitHub
parent 98f1279815
commit 1f6048abe5
2 changed files with 25 additions and 5 deletions
+8 -2
View File
@@ -35,7 +35,7 @@ from .deepseek_v2 import (
_try_load_fp8_indexer_wk,
get_spec_layer_idx_from_weight_name,
)
from .utils import maybe_prefix
from .utils import get_pp_missing_layer_names, maybe_prefix
logger = init_logger(__name__)
@@ -267,6 +267,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
),
)
pp_missing_layer_names = get_pp_missing_layer_names(self)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
_pending_wk_fp8: dict = {} # FP8 indexer wk dequant buffer
@@ -282,7 +283,12 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
name = self._rewrite_spec_layer_name(spec_layer, name)
if _try_load_fp8_indexer_wk(
name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
name,
loaded_weight,
_pending_wk_fp8,
params_dict,
loaded_params,
pp_missing_layer_names,
):
continue
+17 -3
View File
@@ -105,6 +105,7 @@ from .interfaces import (
)
from .utils import (
PPMissingLayer,
get_pp_missing_layer_names,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
@@ -742,7 +743,9 @@ class Indexer(nn.Module):
return self.indexer_op(hidden_states, q_fp8, k, weights)
def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
def _try_load_fp8_indexer_wk(
name, tensor, buf, params_dict, loaded_params, pp_missing_layer_names
):
"""
We fuse the WK and weights_proj projections, but in some checkpoints WK is stored
in FP8 with a separate weight_scale_inv, while weights_proj is stored in BF16.
@@ -758,6 +761,12 @@ def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
return False # WK is not in FP8 format, ignore.
# Buffer this tensor (weight or scale) until both have arrived.
layer_prefix = name.rsplit(".wk.", 1)[0] # e.g. "model.layers.0.self_attn.indexer"
fused_name = f"{layer_prefix}.wk_weights_proj.weight"
if any(
name.startswith(missing_layer_name)
for missing_layer_name in pp_missing_layer_names
):
return True
entry = buf.setdefault(layer_prefix, {})
entry["weight" if is_weight else "scale"] = tensor
if "weight" not in entry or "scale" not in entry:
@@ -775,7 +784,6 @@ def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
)
# Load the dequantized weight into shard 0 of the fused buffer.
fused_name = f"{layer_prefix}.wk_weights_proj.weight"
param = params_dict[fused_name]
param.weight_loader(param, weight_bf16, 0)
loaded_params.add(fused_name)
@@ -1379,6 +1387,7 @@ class DeepseekV2Model(nn.Module):
num_redundant_experts=self.num_redundant_experts,
)
pp_missing_layer_names = get_pp_missing_layer_names(self)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
@@ -1394,7 +1403,12 @@ class DeepseekV2Model(nn.Module):
)
if _try_load_fp8_indexer_wk(
name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
name,
loaded_weight,
_pending_wk_fp8,
params_dict,
loaded_params,
pp_missing_layer_names,
):
continue