From 2acd03030ac08c72a5e154a82db1a2a47c91e6b8 Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Wed, 14 Jan 2026 09:36:35 +0800
Subject: [PATCH] [https://nvbugs/5781589][fix] Implement pp skip forward for all spec workers. (#10578)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 .../_torch/models/modeling_deepseekv3.py     |  2 --
 tensorrt_llm/_torch/models/modeling_glm.py   |  2 --
 .../_torch/models/modeling_speculative.py    |  2 ++
 tensorrt_llm/_torch/models/modeling_utils.py |  5 +--
 tensorrt_llm/_torch/speculative/interface.py | 31 ++++++++++++++++++
 tensorrt_llm/_torch/speculative/mtp.py       | 32 -------------------
 6 files changed, 36 insertions(+), 38 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index f475280f85..f5e3169fb8 100755
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -1721,8 +1721,6 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
                 self.model_config.quant_config.exclude_modules.extend(
                     extend_exclude_modules)
             self.model.layers.extend(self.draft_model.mtp_layers)
-            self.epilogue.extend(self.draft_model.mtp_layers)
-            self.epilogue.append(self.spec_worker)
 
         # Undo any manipulations done to mapping.
         if self.mapping_with_cp is not None:
diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py
index 94ae57cef9..3c97997352 100644
--- a/tensorrt_llm/_torch/models/modeling_glm.py
+++ b/tensorrt_llm/_torch/models/modeling_glm.py
@@ -982,8 +982,6 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
             )
             self.model_config.quant_config.exclude_modules.extend(extend_exclude_modules)
             self.model.layers.extend(self.draft_model.mtp_layers)
-            self.epilogue.extend(self.draft_model.mtp_layers)
-            self.epilogue.append(self.spec_worker)
 
     def forward(
         self,
diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py
index dc4b3b1d54..06888f4422 100755
--- a/tensorrt_llm/_torch/models/modeling_speculative.py
+++ b/tensorrt_llm/_torch/models/modeling_speculative.py
@@ -918,6 +918,8 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
             self.spec_worker = get_spec_worker(model_config.spec_config,
                                                model_config,
                                                model_config.mapping)
+            self.epilogue.append(self.draft_model)
+            self.epilogue.append(self.spec_worker)
 
         if self.draft_config is not None and model_config.spec_config.eagle3_model_arch == "llama3":
             for key, value in self.draft_config.extra_attrs.items():
diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index 16ddb1812a..f577922fcb 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -157,6 +157,8 @@ def skip_forward(
     if hasattr(module, 'skip_forward'):
         module.forward = module.skip_forward
         remove_weights(module, ignore_modules)
+    elif isinstance(module, DecoderModelForCausalLM):
+        remove_weights(module, ignore_modules)
     else:
         logger.warning(
             f"Fail to skip forward since {module.__class__.__name__} "
@@ -301,8 +303,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
         assert num_hidden_layers >= mapping.pp_size, f"{num_hidden_layers} layers are not enough for PP{mapping.pp_size}"
         pp_layer_list = mapping.pp_layers(num_hidden_layers)
         has_pp_layer = len(pp_layer_list) > 0
-        for layer_idx in range(num_hidden_layers):
-            layer = self.layers[layer_idx]
+        for layer_idx, layer in enumerate(self.layers):
             is_last_layer = (layer_idx == num_hidden_layers - 1)
             if layer_idx not in pp_layer_list:
                 # keep next layer's input_layernorm's weights for fusion
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 99e9468f0c..8dffa020c2 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -380,6 +380,37 @@ class SpecWorkerBase(nn.Module, ABC):
         Subclasses should override this property.
         """
 
+    def skip_forward(
+        self,
+        input_ids,
+        position_ids,
+        hidden_states,
+        logits,
+        attn_metadata,
+        spec_metadata,
+        draft_model,
+    ):
+        batch_size = attn_metadata.num_seqs
+        accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
+                                      dtype=torch.int,
+                                      device=logits.device)
+        num_accepted_tokens = torch.ones(batch_size,
+                                         dtype=torch.int,
+                                         device=logits.device)
+        next_draft_tokens = torch.empty((batch_size, self.max_draft_len),
+                                        dtype=torch.int,
+                                        device=logits.device)
+        next_new_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
+                                      dtype=torch.int,
+                                      device=logits.device)
+        return {
+            'logits': logits,
+            'new_tokens': accepted_tokens,
+            'new_tokens_lens': num_accepted_tokens,
+            'next_draft_tokens': next_draft_tokens,
+            'next_new_tokens': next_new_tokens
+        }
+
     def set_guided_decoder(self,
                            guided_decoder: "CapturableGuidedDecoder") -> bool:
         self.guided_decoder = guided_decoder
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 8d037275a0..1e67caf561 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -558,38 +558,6 @@ class MTPWorker(SpecWorkerBase):
             'next_new_tokens': next_new_tokens
         }
 
-    def skip_forward(
-        self,
-        input_ids,
-        position_ids,
-        hidden_states,
-        logits,
-        attn_metadata,
-        spec_metadata,
-        draft_model,
-    ):
-        batch_size = attn_metadata.num_seqs
-        mtp_num_modules = self.spec_config.num_nextn_predict_layers
-        accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
-                                      dtype=torch.int,
-                                      device=logits.device)
-        num_accepted_tokens = torch.ones(batch_size,
-                                         dtype=torch.int,
-                                         device=logits.device)
-        next_draft_tokens = torch.empty((batch_size, mtp_num_modules),
-                                        dtype=torch.int,
-                                        device=logits.device)
-        next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
-                                      dtype=torch.int,
-                                      device=logits.device)
-        return {
-            'logits': logits,
-            'new_tokens': accepted_tokens,
-            'new_tokens_lens': num_accepted_tokens,
-            'next_draft_tokens': next_draft_tokens,
-            'next_new_tokens': next_new_tokens
-        }
-
     def update_mtp_hidden_states(
         self,
         input_ids: torch.IntTensor,
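
For readers outside the codebase, the sketch below illustrates the pattern this patch generalizes: on non-last pipeline-parallel ranks, a module's `skip_forward` is rebound in place of `forward`, so epilogue modules (here, the draft model and the spec worker) return shape-compatible placeholder outputs instead of running. `ToySpecWorker` and `skip_module_forward` are simplified stand-ins, not the TensorRT-LLM classes; only the rebinding idea and the dummy-output shapes mirror the patch.

```python
# Minimal sketch of the epilogue skip-forward pattern (not the TensorRT-LLM
# implementation). Assumed/hypothetical names: ToySpecWorker, skip_module_forward.
import torch
import torch.nn as nn


class ToySpecWorker(nn.Module):
    """Stand-in for a spec worker; `max_draft_len` mirrors SpecWorkerBase."""

    def __init__(self, max_draft_len: int):
        super().__init__()
        self.max_draft_len = max_draft_len

    def forward(self, logits, num_seqs):
        # Real drafting/verification would run here on the last PP rank.
        raise NotImplementedError

    def skip_forward(self, logits, num_seqs):
        # Shape-compatible dummy outputs, analogous to SpecWorkerBase.skip_forward.
        device = logits.device
        return {
            'logits': logits,
            'new_tokens': torch.empty((num_seqs, self.max_draft_len + 1),
                                      dtype=torch.int, device=device),
            'new_tokens_lens': torch.ones(num_seqs, dtype=torch.int,
                                          device=device),
            'next_draft_tokens': torch.empty((num_seqs, self.max_draft_len),
                                             dtype=torch.int, device=device),
        }


def skip_module_forward(module: nn.Module) -> None:
    # Mirrors the rebinding done in modeling_utils.skip_forward; weight
    # removal is omitted in this sketch.
    if hasattr(module, 'skip_forward'):
        module.forward = module.skip_forward


# Usage on a hypothetical non-last pipeline rank:
worker = ToySpecWorker(max_draft_len=3)
skip_module_forward(worker)
out = worker(torch.zeros(2, 8), num_seqs=2)   # dispatches to skip_forward
print(out['new_tokens'].shape)                # torch.Size([2, 4])
```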