mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00
[TRTLLM-10669][fix] Fix Eagle3 draft model weight loading for throughput checkpoint (#11010)
Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
parent
9384cf8458
commit
c37531c3f7
@ -482,9 +482,26 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
|
||||
# Remap weight names: some Eagle3 checkpoints use "layers.X.*" naming convention
|
||||
# while the model expects "midlayer.*" naming. Handle both formats.
|
||||
import re
|
||||
remapped_weights = {}
|
||||
# Access num_layers from the inner draft model (self.model is Eagle3DraftModel)
|
||||
num_layers = self.model.num_layers
|
||||
for k, v in weights.items():
|
||||
new_k = k
|
||||
# For single-layer models: "layers.0.*" -> "midlayer.*"
|
||||
# For multi-layer models: "layers.X.*" -> "midlayer.X.*"
|
||||
if num_layers == 1:
|
||||
# Single layer: layers.0.foo -> midlayer.foo
|
||||
new_k = re.sub(r'^layers\.0\.', 'midlayer.', new_k)
|
||||
else:
|
||||
# Multi-layer: layers.X.foo -> midlayer.X.foo
|
||||
new_k = re.sub(r'^layers\.(\d+)\.', r'midlayer.\1.', new_k)
|
||||
remapped_weights[new_k] = v
|
||||
|
||||
new_weights = {}
|
||||
for k, v in weights.items():
|
||||
for k, v in remapped_weights.items():
|
||||
if 'lm_head' not in k:
|
||||
new_k = "model." + k
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user