mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
fix: glm5.1 pp model loading (#42944)
Signed-off-by: UranusSeven <109661872+UranusSeven@users.noreply.github.com>
This commit is contained in:
@@ -35,7 +35,7 @@ from .deepseek_v2 import (
|
||||
_try_load_fp8_indexer_wk,
|
||||
get_spec_layer_idx_from_weight_name,
|
||||
)
|
||||
from .utils import maybe_prefix
|
||||
from .utils import get_pp_missing_layer_names, maybe_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -267,6 +267,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
|
||||
),
|
||||
)
|
||||
|
||||
pp_missing_layer_names = get_pp_missing_layer_names(self)
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
_pending_wk_fp8: dict = {} # FP8 indexer wk dequant buffer
|
||||
@@ -282,7 +283,12 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts):
|
||||
name = self._rewrite_spec_layer_name(spec_layer, name)
|
||||
|
||||
if _try_load_fp8_indexer_wk(
|
||||
name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
|
||||
name,
|
||||
loaded_weight,
|
||||
_pending_wk_fp8,
|
||||
params_dict,
|
||||
loaded_params,
|
||||
pp_missing_layer_names,
|
||||
):
|
||||
continue
|
||||
|
||||
|
||||
@@ -105,6 +105,7 @@ from .interfaces import (
|
||||
)
|
||||
from .utils import (
|
||||
PPMissingLayer,
|
||||
get_pp_missing_layer_names,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
@@ -742,7 +743,9 @@ class Indexer(nn.Module):
|
||||
return self.indexer_op(hidden_states, q_fp8, k, weights)
|
||||
|
||||
|
||||
def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
|
||||
def _try_load_fp8_indexer_wk(
|
||||
name, tensor, buf, params_dict, loaded_params, pp_missing_layer_names
|
||||
):
|
||||
"""
|
||||
We fuse the WK and weights_proj projections, but in some checkpoints WK is stored
|
||||
in FP8 with a separate weight_scale_inv, while weights_proj is stored in BF16.
|
||||
@@ -758,6 +761,12 @@ def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
|
||||
return False # WK is not in FP8 format, ignore.
|
||||
# Buffer this tensor (weight or scale) until both have arrived.
|
||||
layer_prefix = name.rsplit(".wk.", 1)[0] # e.g. "model.layers.0.self_attn.indexer"
|
||||
fused_name = f"{layer_prefix}.wk_weights_proj.weight"
|
||||
if any(
|
||||
name.startswith(missing_layer_name)
|
||||
for missing_layer_name in pp_missing_layer_names
|
||||
):
|
||||
return True
|
||||
entry = buf.setdefault(layer_prefix, {})
|
||||
entry["weight" if is_weight else "scale"] = tensor
|
||||
if "weight" not in entry or "scale" not in entry:
|
||||
@@ -775,7 +784,6 @@ def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
|
||||
)
|
||||
|
||||
# Load the dequantized weight into shard 0 of the fused buffer.
|
||||
fused_name = f"{layer_prefix}.wk_weights_proj.weight"
|
||||
param = params_dict[fused_name]
|
||||
param.weight_loader(param, weight_bf16, 0)
|
||||
loaded_params.add(fused_name)
|
||||
@@ -1379,6 +1387,7 @@ class DeepseekV2Model(nn.Module):
|
||||
num_redundant_experts=self.num_redundant_experts,
|
||||
)
|
||||
|
||||
pp_missing_layer_names = get_pp_missing_layer_names(self)
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
@@ -1394,7 +1403,12 @@ class DeepseekV2Model(nn.Module):
|
||||
)
|
||||
|
||||
if _try_load_fp8_indexer_wk(
|
||||
name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
|
||||
name,
|
||||
loaded_weight,
|
||||
_pending_wk_fp8,
|
||||
params_dict,
|
||||
loaded_params,
|
||||
pp_missing_layer_names,
|
||||
):
|
||||
continue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user