From c37531c3f767e5b4c9a3682608ec8aec348ef62c Mon Sep 17 00:00:00 2001 From: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:25:51 -0800 Subject: [PATCH] [TRTLLM-10669][fix] Fix Eagle3 draft model weight loading for throughput checkpoint (#11010) Signed-off-by: Guiju Zhang <7135567+cascade812@users.noreply.github.com> Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- .../_torch/models/modeling_speculative.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index 1849f1acf2..a22e196a93 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -482,9 +482,26 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, ) def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper): + # Remap weight names: some Eagle3 checkpoints use "layers.X.*" naming convention + # while the model expects "midlayer.*" naming. Handle both formats. + import re + remapped_weights = {} + # Access num_layers from the inner draft model (self.model is Eagle3DraftModel) + num_layers = self.model.num_layers + for k, v in weights.items(): + new_k = k + # For single-layer models: "layers.0.*" -> "midlayer.*" + # For multi-layer models: "layers.X.*" -> "midlayer.X.*" + if num_layers == 1: + # Single layer: layers.0.foo -> midlayer.foo + new_k = re.sub(r'^layers\.0\.', 'midlayer.', new_k) + else: + # Multi-layer: layers.X.foo -> midlayer.X.foo + new_k = re.sub(r'^layers\.(\d+)\.', r'midlayer.\1.', new_k) + remapped_weights[new_k] = v new_weights = {} - for k, v in weights.items(): + for k, v in remapped_weights.items(): if 'lm_head' not in k: new_k = "model." + k else: