[https://nvbugs/5781589][fix] Implement pp skip forward for all spec workers. (#10578)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Yuxian Qiu 2026-01-14 09:36:35 +08:00 committed by GitHub
parent bc119f5644
commit 2acd03030a
6 changed files with 36 additions and 38 deletions
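
In short: the per-model epilogue registration in DeepseekV3ForCausalLM and Glm4MoeForCausalLM is removed, SpecDecOneEngineForCausalLM now registers both the draft model and the spec worker in its epilogue, and the MTP-specific skip_forward is promoted from MTPWorker into SpecWorkerBase so that every speculative-decoding worker exposes a pipeline-parallel (PP) skip path. The sketch below illustrates the underlying pattern only; DummySpecWorker, its argument list, and the __main__ driver are illustrative stand-ins rather than TensorRT-LLM APIs. A module that should not run the speculative step on a given rank gets its forward swapped for a skip_forward that returns correctly shaped placeholder tensors.

# Minimal sketch of the PP skip-forward pattern (assumed names, not the real APIs).
import torch
import torch.nn as nn


def skip_forward(module: nn.Module) -> None:
    # Mirrors the utility touched in this diff: prefer a module-provided skip path.
    if hasattr(module, 'skip_forward'):
        module.forward = module.skip_forward


class DummySpecWorker(nn.Module):
    # Stand-in spec worker; in the real change the base class supplies skip_forward.

    def __init__(self, max_draft_len: int):
        super().__init__()
        self.max_draft_len = max_draft_len

    def forward(self, logits: torch.Tensor, num_seqs: int):
        raise NotImplementedError("real speculative drafting/verification goes here")

    def skip_forward(self, logits: torch.Tensor, num_seqs: int):
        # Return placeholder outputs with the shapes downstream code expects,
        # so a rank that skips the speculative step still yields a valid result dict.
        device = logits.device
        accepted_tokens = torch.empty((num_seqs, self.max_draft_len + 1),
                                      dtype=torch.int, device=device)
        num_accepted_tokens = torch.ones(num_seqs, dtype=torch.int, device=device)
        next_draft_tokens = torch.empty((num_seqs, self.max_draft_len),
                                        dtype=torch.int, device=device)
        return {
            'logits': logits,
            'new_tokens': accepted_tokens,
            'new_tokens_lens': num_accepted_tokens,
            'next_draft_tokens': next_draft_tokens,
        }


if __name__ == '__main__':
    worker = DummySpecWorker(max_draft_len=3)
    skip_forward(worker)  # this rank skips the spec step; forward is now the skip path
    out = worker(torch.zeros(2, 8), num_seqs=2)
    print({k: tuple(v.shape) for k, v in out.items()})

Run standalone, the driver just prints the placeholder shapes; in the actual diff the swap is done by the skip_forward utility below and the shapes come from SpecWorkerBase.skip_forward.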


@@ -1721,8 +1721,6 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
             self.model_config.quant_config.exclude_modules.extend(
                 extend_exclude_modules)
         self.model.layers.extend(self.draft_model.mtp_layers)
-        self.epilogue.extend(self.draft_model.mtp_layers)
-        self.epilogue.append(self.spec_worker)
 
         # Undo any manipulations done to mapping.
         if self.mapping_with_cp is not None:


@@ -982,8 +982,6 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
         )
         self.model_config.quant_config.exclude_modules.extend(extend_exclude_modules)
         self.model.layers.extend(self.draft_model.mtp_layers)
-        self.epilogue.extend(self.draft_model.mtp_layers)
-        self.epilogue.append(self.spec_worker)
 
     def forward(
         self,


@@ -918,6 +918,8 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
             self.spec_worker = get_spec_worker(model_config.spec_config,
                                                model_config,
                                                model_config.mapping)
+            self.epilogue.append(self.draft_model)
+            self.epilogue.append(self.spec_worker)
 
         if self.draft_config is not None and model_config.spec_config.eagle3_model_arch == "llama3":
             for key, value in self.draft_config.extra_attrs.items():


@@ -157,6 +157,8 @@ def skip_forward(
     if hasattr(module, 'skip_forward'):
         module.forward = module.skip_forward
         remove_weights(module, ignore_modules)
+    elif isinstance(module, DecoderModelForCausalLM):
+        remove_weights(module, ignore_modules)
     else:
         logger.warning(
             f"Fail to skip forward since {module.__class__.__name__} "
@@ -301,8 +303,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
         assert num_hidden_layers >= mapping.pp_size, f"{num_hidden_layers} layers are not enough for PP{mapping.pp_size}"
         pp_layer_list = mapping.pp_layers(num_hidden_layers)
         has_pp_layer = len(pp_layer_list) > 0
-        for layer_idx in range(num_hidden_layers):
-            layer = self.layers[layer_idx]
+        for layer_idx, layer in enumerate(self.layers):
             is_last_layer = (layer_idx == num_hidden_layers - 1)
             if layer_idx not in pp_layer_list:
                 # keep next layer's input_layernorm's weights for fusion


@@ -380,6 +380,37 @@ class SpecWorkerBase(nn.Module, ABC):
         Subclasses should override this property.
         """
 
+    def skip_forward(
+        self,
+        input_ids,
+        position_ids,
+        hidden_states,
+        logits,
+        attn_metadata,
+        spec_metadata,
+        draft_model,
+    ):
+        batch_size = attn_metadata.num_seqs
+        accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
+                                      dtype=torch.int,
+                                      device=logits.device)
+        num_accepted_tokens = torch.ones(batch_size,
+                                         dtype=torch.int,
+                                         device=logits.device)
+        next_draft_tokens = torch.empty((batch_size, self.max_draft_len),
+                                        dtype=torch.int,
+                                        device=logits.device)
+        next_new_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
+                                      dtype=torch.int,
+                                      device=logits.device)
+        return {
+            'logits': logits,
+            'new_tokens': accepted_tokens,
+            'new_tokens_lens': num_accepted_tokens,
+            'next_draft_tokens': next_draft_tokens,
+            'next_new_tokens': next_new_tokens
+        }
+
     def set_guided_decoder(self,
                            guided_decoder: "CapturableGuidedDecoder") -> bool:
         self.guided_decoder = guided_decoder

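Note: with skip_forward now provided by SpecWorkerBase (previous hunk), the MTP-specific copy below becomes redundant and is deleted. The base implementation sizes its placeholders with self.max_draft_len, which for MTP presumably corresponds to the spec_config.num_nextn_predict_layers value used in the removed code.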

@@ -558,38 +558,6 @@ class MTPWorker(SpecWorkerBase):
             'next_new_tokens': next_new_tokens
         }
 
-    def skip_forward(
-        self,
-        input_ids,
-        position_ids,
-        hidden_states,
-        logits,
-        attn_metadata,
-        spec_metadata,
-        draft_model,
-    ):
-        batch_size = attn_metadata.num_seqs
-        mtp_num_modules = self.spec_config.num_nextn_predict_layers
-        accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
-                                      dtype=torch.int,
-                                      device=logits.device)
-        num_accepted_tokens = torch.ones(batch_size,
-                                         dtype=torch.int,
-                                         device=logits.device)
-        next_draft_tokens = torch.empty((batch_size, mtp_num_modules),
-                                        dtype=torch.int,
-                                        device=logits.device)
-        next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
-                                      dtype=torch.int,
-                                      device=logits.device)
-        return {
-            'logits': logits,
-            'new_tokens': accepted_tokens,
-            'new_tokens_lens': num_accepted_tokens,
-            'next_draft_tokens': next_draft_tokens,
-            'next_new_tokens': next_new_tokens
-        }
-
     def update_mtp_hidden_states(
         self,
         input_ids: torch.IntTensor,