From f8bd372c59c48fecd8906d199d893a6b94409b75 Mon Sep 17 00:00:00 2001
From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Date: Wed, 21 May 2025 13:30:21 +0800
Subject: [PATCH] Cherry pick https://github.com/NVIDIA/TensorRT-LLM/pull/4447
 (#4517)

fix: skip weights defined in create_weights for pp. (#4447)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_utils.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index daf8138f38..63c6d962ea 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -152,9 +152,15 @@ def skip_forward(
     if hasattr(module, 'skip_forward'):
         module.forward = module.skip_forward
         remove_weights(module, ignore_modules)
+    else:
+        logger.warning(
+            f"Fail to skip forward since {module.__class__.__name__} "
+            f"does not have `skip_forward`.")
 
 
 def forward_after_recv(forward_fn):
+    if hasattr(forward_fn, "__wrapped_by_forward_after_recv__"):
+        return forward_fn
 
     def forward_after_recv_fn(
         position_ids,
@@ -176,10 +182,13 @@ def forward_after_recv(forward_fn):
             **kwargs,
         )
 
+    forward_after_recv_fn.__wrapped_by_forward_after_recv__ = True
     return forward_after_recv_fn
 
 
 def forward_before_send(forward_fn):
+    if hasattr(forward_fn, "__wrapped_by_forward_before_send__"):
+        return forward_fn
 
     def forward_before_send_fn(
         position_ids,
@@ -204,6 +213,7 @@ def forward_before_send(forward_fn):
             pp_send(hidden_states)
         return output
 
+    forward_before_send_fn.__wrapped_by_forward_before_send__ = True
     return forward_before_send_fn
 
 
@@ -411,6 +421,8 @@ class DecoderModelForCausalLM(nn.Module,
         for module in self.epilogue:
             skip_forward(module)
 
+        self.model.__pp_init__()
+
     def __post_init__(self):
         # 1. mixed precision
         quant_config_dict = self.model_config.quant_config_dict
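
For context, the guards added to `forward_after_recv` and `forward_before_send` follow a common idempotent-wrapper pattern: a sentinel attribute on the returned function marks it as already wrapped, so applying the wrapper a second time (e.g. if `__pp_init__` ends up being invoked more than once) returns the existing function instead of stacking another `pp_recv`/`pp_send` layer. The following standalone sketch illustrates that pattern only; the names `wrap_once` and `greet` are hypothetical and are not TensorRT-LLM APIs.

```python
# Minimal sketch of the idempotent-wrapper guard used in the patch.
# `wrap_once` and `greet` are illustrative names, not TensorRT-LLM code.


def wrap_once(fn):
    # If fn was already wrapped, return it unchanged instead of
    # nesting another wrapper around it.
    if hasattr(fn, "__wrapped_by_wrap_once__"):
        return fn

    def wrapper(*args, **kwargs):
        # Stands in for the extra work (pp_recv/pp_send in the patch)
        # that must run exactly once per call, not once per wrap.
        print("before call")
        return fn(*args, **kwargs)

    # Sentinel attribute marking the function as already wrapped.
    wrapper.__wrapped_by_wrap_once__ = True
    return wrapper


def greet(name):
    return f"hello, {name}"


greet = wrap_once(greet)
greet = wrap_once(greet)  # second application is a no-op
print(greet("pp"))        # "before call" is printed exactly once
```

The sentinel-attribute check is what makes re-running the pipeline-parallel initialization safe: without it, each extra wrap would add another receive/send per forward call.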