Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-04 02:02:01 +08:00
[https://nvbugs/5781589][fix] Implement pp skip forward for all spec workers. (#10578)
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Commit: 2acd03030a
Parent: bc119f5644
@@ -1721,8 +1721,6 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
            self.model_config.quant_config.exclude_modules.extend(
                extend_exclude_modules)
            self.model.layers.extend(self.draft_model.mtp_layers)
            self.epilogue.extend(self.draft_model.mtp_layers)
            self.epilogue.append(self.spec_worker)

        # Undo any manipulations done to mapping.
        if self.mapping_with_cp is not None:
@@ -982,8 +982,6 @@ class Glm4MoeForCausalLM(SpecDecOneEngineForCausalLM[Glm4Model, PretrainedConfig
            )
            self.model_config.quant_config.exclude_modules.extend(extend_exclude_modules)
            self.model.layers.extend(self.draft_model.mtp_layers)
            self.epilogue.extend(self.draft_model.mtp_layers)
            self.epilogue.append(self.spec_worker)

    def forward(
        self,
@@ -918,6 +918,8 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
            self.spec_worker = get_spec_worker(model_config.spec_config,
                                               model_config,
                                               model_config.mapping)
            self.epilogue.append(self.draft_model)
            self.epilogue.append(self.spec_worker)

        if self.draft_config is not None and model_config.spec_config.eagle3_model_arch == "llama3":
            for key, value in self.draft_config.extra_attrs.items():
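This hunk registers the draft model and the spec worker in the shared SpecDecOneEngineForCausalLM epilogue, which is why the per-model `self.epilogue` registrations in the DeepseekV3ForCausalLM and Glm4MoeForCausalLM hunks above are dropped. As a minimal, self-contained sketch of the underlying pattern (illustrative names only, not the TensorRT-LLM code): modules collected in an epilogue list can have their forward rebound to a cheap skip_forward on pipeline-parallel ranks that do not own the last stage.

import torch.nn as nn

class ToyWorker(nn.Module):
    """Toy stand-in for a spec worker that offers a skip path."""

    def forward(self, hidden_states):
        return hidden_states * 2          # real compute

    def skip_forward(self, hidden_states):
        return hidden_states              # placeholder on skipped ranks


def apply_pp_skip(epilogue, owns_last_stage):
    # On ranks that do not own the last pipeline stage, rebind forward to
    # skip_forward so epilogue modules become cheap pass-throughs.
    if owns_last_stage:
        return
    for module in epilogue:
        if hasattr(module, "skip_forward"):
            module.forward = module.skip_forward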
@@ -157,6 +157,8 @@ def skip_forward(
    if hasattr(module, 'skip_forward'):
        module.forward = module.skip_forward
        remove_weights(module, ignore_modules)
    elif isinstance(module, DecoderModelForCausalLM):
        remove_weights(module, ignore_modules)
    else:
        logger.warning(
            f"Fail to skip forward since {module.__class__.__name__} "
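For reference, a self-contained sketch of the dispatch pattern shown in this hunk, with illustrative stand-ins for remove_weights and DecoderModelForCausalLM (a simplification, not the real helper): modules that provide skip_forward get their forward rebound and their weights dropped, decoder models only get their weights dropped, and anything else produces a warning.

import logging
import torch.nn as nn

logger = logging.getLogger(__name__)

class DecoderLike(nn.Module):
    """Stand-in for DecoderModelForCausalLM in this sketch."""

def drop_parameters(module: nn.Module) -> None:
    # Illustrative stand-in for remove_weights(): release direct parameters.
    for name, _ in list(module.named_parameters(recurse=False)):
        delattr(module, name)

def skip_forward_sketch(module: nn.Module) -> None:
    if hasattr(module, "skip_forward"):
        module.forward = module.skip_forward   # cheap placeholder forward
        drop_parameters(module)                # weights are no longer needed
    elif isinstance(module, DecoderLike):
        drop_parameters(module)                # keep forward, free memory
    else:
        logger.warning("Cannot skip forward for %s", type(module).__name__)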
@@ -301,8 +303,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
        assert num_hidden_layers >= mapping.pp_size, f"{num_hidden_layers} layers are not enough for PP{mapping.pp_size}"
        pp_layer_list = mapping.pp_layers(num_hidden_layers)
        has_pp_layer = len(pp_layer_list) > 0
        for layer_idx in range(num_hidden_layers):
            layer = self.layers[layer_idx]
        for layer_idx, layer in enumerate(self.layers):
            is_last_layer = (layer_idx == num_hidden_layers - 1)
            if layer_idx not in pp_layer_list:
                # keep next layer's input_layernorm's weights for fusion
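The loop above relies on mapping.pp_layers() to decide which decoder layers the current pipeline rank actually owns; layers outside that list are switched to their skip path. Purely as an illustration (the real partitioning in Mapping.pp_layers may differ), an even contiguous split could look like this:

def even_pp_layers(num_hidden_layers: int, pp_rank: int, pp_size: int) -> list:
    """Illustrative contiguous split of layers across pipeline ranks."""
    base, extra = divmod(num_hidden_layers, pp_size)
    start = pp_rank * base + min(pp_rank, extra)
    end = start + base + (1 if pp_rank < extra else 0)
    return list(range(start, end))

# Rank 1 of a 4-way pipeline over 32 layers owns layers 8..15; every other
# layer index falls outside the list and would be skip-forwarded.
print(even_pp_layers(32, pp_rank=1, pp_size=4))   # [8, 9, 10, 11, 12, 13, 14, 15]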
@@ -380,6 +380,37 @@ class SpecWorkerBase(nn.Module, ABC):
        Subclasses should override this property.
        """

    def skip_forward(
        self,
        input_ids,
        position_ids,
        hidden_states,
        logits,
        attn_metadata,
        spec_metadata,
        draft_model,
    ):
        batch_size = attn_metadata.num_seqs
        accepted_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
                                      dtype=torch.int,
                                      device=logits.device)
        num_accepted_tokens = torch.ones(batch_size,
                                         dtype=torch.int,
                                         device=logits.device)
        next_draft_tokens = torch.empty((batch_size, self.max_draft_len),
                                        dtype=torch.int,
                                        device=logits.device)
        next_new_tokens = torch.empty((batch_size, (self.max_draft_len + 1)),
                                      dtype=torch.int,
                                      device=logits.device)
        return {
            'logits': logits,
            'new_tokens': accepted_tokens,
            'new_tokens_lens': num_accepted_tokens,
            'next_draft_tokens': next_draft_tokens,
            'next_new_tokens': next_new_tokens
        }

    def set_guided_decoder(self,
                           guided_decoder: "CapturableGuidedDecoder") -> bool:
        self.guided_decoder = guided_decoder
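The skip_forward added to SpecWorkerBase preserves the output contract of a real speculative step on pipeline ranks that skip the worker: logits pass through unchanged and the token tensors are just correctly shaped placeholders. A small illustration of the resulting shapes, with assumed values for batch size and draft length:

import torch

batch_size, max_draft_len = 4, 3          # illustrative values
device = torch.device("cpu")              # stands in for logits.device

accepted_tokens = torch.empty((batch_size, max_draft_len + 1), dtype=torch.int, device=device)
num_accepted_tokens = torch.ones(batch_size, dtype=torch.int, device=device)
next_draft_tokens = torch.empty((batch_size, max_draft_len), dtype=torch.int, device=device)
next_new_tokens = torch.empty((batch_size, max_draft_len + 1), dtype=torch.int, device=device)

print(accepted_tokens.shape)       # torch.Size([4, 4])
print(num_accepted_tokens.shape)   # torch.Size([4])
print(next_draft_tokens.shape)     # torch.Size([4, 3])
print(next_new_tokens.shape)       # torch.Size([4, 4])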
@@ -558,38 +558,6 @@ class MTPWorker(SpecWorkerBase):
            'next_new_tokens': next_new_tokens
        }

    def skip_forward(
        self,
        input_ids,
        position_ids,
        hidden_states,
        logits,
        attn_metadata,
        spec_metadata,
        draft_model,
    ):
        batch_size = attn_metadata.num_seqs
        mtp_num_modules = self.spec_config.num_nextn_predict_layers
        accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
                                      dtype=torch.int,
                                      device=logits.device)
        num_accepted_tokens = torch.ones(batch_size,
                                         dtype=torch.int,
                                         device=logits.device)
        next_draft_tokens = torch.empty((batch_size, mtp_num_modules),
                                        dtype=torch.int,
                                        device=logits.device)
        next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
                                      dtype=torch.int,
                                      device=logits.device)
        return {
            'logits': logits,
            'new_tokens': accepted_tokens,
            'new_tokens_lens': num_accepted_tokens,
            'next_draft_tokens': next_draft_tokens,
            'next_new_tokens': next_new_tokens
        }

    def update_mtp_hidden_states(
        self,
        input_ids: torch.IntTensor,
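The MTP-specific skip_forward deleted here mirrors the new base-class version line for line, except that it sizes its placeholder tensors with spec_config.num_nextn_predict_layers instead of self.max_draft_len. A hypothetical sketch of how a subclass could expose that correspondence so the base implementation applies (the real MTPWorker may define max_draft_len elsewhere):

class MTPWorkerSketch:
    """Hypothetical subclass wiring, not the real MTPWorker."""

    def __init__(self, spec_config):
        self.spec_config = spec_config

    @property
    def max_draft_len(self) -> int:
        # Each next-n predict layer drafts one token, so the module count
        # doubles as the draft length (assumption for this sketch).
        return self.spec_config.num_nextn_predict_layers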