mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Cherry pick https://github.com/NVIDIA/TensorRT-LLM/pull/4447 (#4517)
fix: skip weights defined in create_weights for pp. (#4447) Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
This commit is contained in:
parent
74928b55e9
commit
f8bd372c59
@ -152,9 +152,15 @@ def skip_forward(
|
||||
if hasattr(module, 'skip_forward'):
|
||||
module.forward = module.skip_forward
|
||||
remove_weights(module, ignore_modules)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Fail to skip forward since {module.__class__.__name__} "
|
||||
f"does not have `skip_forward`.")
|
||||
|
||||
|
||||
def forward_after_recv(forward_fn):
|
||||
if hasattr(forward_fn, "__wrapped_by_forward_after_recv__"):
|
||||
return forward_fn
|
||||
|
||||
def forward_after_recv_fn(
|
||||
position_ids,
|
||||
@ -176,10 +182,13 @@ def forward_after_recv(forward_fn):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
forward_after_recv_fn.__wrapped_by_forward_after_recv__ = True
|
||||
return forward_after_recv_fn
|
||||
|
||||
|
||||
def forward_before_send(forward_fn):
|
||||
if hasattr(forward_fn, "__wrapped_by_forward_before_send__"):
|
||||
return forward_fn
|
||||
|
||||
def forward_before_send_fn(
|
||||
position_ids,
|
||||
@ -204,6 +213,7 @@ def forward_before_send(forward_fn):
|
||||
pp_send(hidden_states)
|
||||
return output
|
||||
|
||||
forward_before_send_fn.__wrapped_by_forward_before_send__ = True
|
||||
return forward_before_send_fn
|
||||
|
||||
|
||||
@ -411,6 +421,8 @@ class DecoderModelForCausalLM(nn.Module,
|
||||
for module in self.epilogue:
|
||||
skip_forward(module)
|
||||
|
||||
self.model.__pp_init__()
|
||||
|
||||
def __post_init__(self):
|
||||
# 1. mixed precision
|
||||
quant_config_dict = self.model_config.quant_config_dict
|
||||
|
||||
Loading…
Reference in New Issue
Block a user