[TRTLLM-10669][fix] Fix Eagle3 draft model weight loading for throughput checkpoint (#11010)

Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
This commit is contained in:
Guiju Zhang 2026-01-28 15:25:51 -08:00 committed by Yanchao Lu
parent 9384cf8458
commit c37531c3f7

View File

@ -482,9 +482,26 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
)
def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
# Remap weight names: some Eagle3 checkpoints use "layers.X.*" naming convention
# while the model expects "midlayer.*" naming. Handle both formats.
import re
remapped_weights = {}
# Access num_layers from the inner draft model (self.model is Eagle3DraftModel)
num_layers = self.model.num_layers
for k, v in weights.items():
new_k = k
# For single-layer models: "layers.0.*" -> "midlayer.*"
# For multi-layer models: "layers.X.*" -> "midlayer.X.*"
if num_layers == 1:
# Single layer: layers.0.foo -> midlayer.foo
new_k = re.sub(r'^layers\.0\.', 'midlayer.', new_k)
else:
# Multi-layer: layers.X.foo -> midlayer.X.foo
new_k = re.sub(r'^layers\.(\d+)\.', r'midlayer.\1.', new_k)
remapped_weights[new_k] = v
new_weights = {}
for k, v in weights.items():
for k, v in remapped_weights.items():
if 'lm_head' not in k:
new_k = "model." + k
else: