From 2d5ebb3fe87b8027dfe4b392a51359a8487b7b33 Mon Sep 17 00:00:00 2001
From: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
Date: Wed, 11 Feb 2026 10:01:36 -0700
Subject: [PATCH] [None][chore] Merge residual+hidden into layer norm at the end of each NemotronH MTP, and remove a % operation (#11406)

Signed-off-by: Harris Nover <249353502+hnover-nv@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_nemotron_h.py | 15 ++++++---------
 1 file changed, 6 insertions(+), 9 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py
index ee77603ba5..e50c67991c 100644
--- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py
+++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py
@@ -658,11 +658,10 @@ class NemotronHMTPDecoderLayer(NemotronHLayer):
         )
 
         if self.has_end_norm:
-            if residual is not None:
-                hidden_states = hidden_states + residual
-                residual = None
-
-            hidden_states = self.final_layernorm(hidden_states)
+            hidden_states, residual = self.final_layernorm(
+                hidden_states, residual)
+            # The last step, so don't forward the residual.
+            residual = None
 
         return hidden_states, residual
 
@@ -690,9 +689,7 @@ class NemotronHMTP(nn.Module):
 
         # Build pattern-based layers
         self.layers = nn.ModuleDict()
-        for i in range(self.pattern_len):
-            step_rel_idx = i % self.pattern_len
-
+        for step_rel_idx in range(self.pattern_len):
             char = self.pattern_str[step_rel_idx]
             is_start_of_step = step_rel_idx == 0
 
@@ -710,7 +707,7 @@ class NemotronHMTP(nn.Module):
                 skip_create_weights_in_init,
             )
 
-            self.layers[str(i)] = NemotronHMTPDecoderLayer(
+            self.layers[str(step_rel_idx)] = NemotronHMTPDecoderLayer(
                 model_config=sublayer_model_config,
                 layer_idx=self.layer_idx,
                 aux_stream_dict=aux_stream_dict,
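
Note on the fused add+norm pattern the first hunk relies on: many fused RMSNorm implementations accept an optional residual, add it to the hidden states before normalizing, and return both the normalized output and the pre-norm sum as the new residual. The sketch below is a minimal illustration assuming that convention; FusedAddRMSNorm is a hypothetical name, not the actual TensorRT-LLM module behind final_layernorm. The second and third hunks are a no-op simplification: since i already ranges over [0, pattern_len), i % self.pattern_len == i, so the loop variable can be used directly.

# Minimal sketch of a fused add+norm module, assuming the common convention
# used by fused RMSNorm kernels; names and details are illustrative and not
# the actual TensorRT-LLM implementation.
import torch
from torch import nn


class FusedAddRMSNorm(nn.Module):  # hypothetical class, for illustration only
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor, residual=None):
        if residual is not None:
            # Fold the residual add into the norm; the pre-norm sum is
            # returned so the caller can forward it as the next residual.
            hidden_states = hidden_states + residual
        residual = hidden_states
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        return self.weight * hidden_states, residual

Under that convention, hidden_states, residual = self.final_layernorm(hidden_states, residual) reproduces the removed explicit addition followed by the norm, and the layer then sets residual = None because this is the final MTP step and nothing downstream consumes it.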