fix: skip weights defined in create_weights for pp. (#4447)

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Yuxian Qiu 2025-05-21 13:30:21 +08:00 committed by GitHub
parent 74928b55e9
commit f8bd372c59


@@ -152,9 +152,15 @@ def skip_forward(
    if hasattr(module, 'skip_forward'):
        module.forward = module.skip_forward
        remove_weights(module, ignore_modules)
    else:
        logger.warning(
            f"Failed to skip forward since {module.__class__.__name__} "
            f"does not have `skip_forward`.")


def forward_after_recv(forward_fn):
    if hasattr(forward_fn, "__wrapped_by_forward_after_recv__"):
        return forward_fn

    def forward_after_recv_fn(
        position_ids,
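
The headline change in this hunk is the `remove_weights(module, ignore_modules)` call: once a module's forward is skipped on a pipeline-parallel rank, the weights it allocated (including ones defined lazily in `create_weights`, per the commit title) can be dropped. A minimal sketch of what such a helper could look like; the real `remove_weights` in this repo may differ, and the prefix-matching semantics of `ignore_modules` are an assumption:

    from torch import nn

    def remove_weights(module: nn.Module, ignore_modules=()):
        """Hypothetical sketch: free the parameters of a skipped module."""
        for name, submodule in module.named_modules():
            # Keep submodules the caller explicitly preserves (assumed to be
            # matched by qualified-name prefix).
            if any(name.startswith(prefix) for prefix in ignore_modules):
                continue
            # Replace every registered parameter with None so this rank holds
            # no weight memory for a module it never executes.
            for param_name in list(submodule._parameters):
                submodule._parameters[param_name] = None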
@@ -176,10 +182,13 @@ def forward_after_recv(forward_fn):
            **kwargs,
        )

    forward_after_recv_fn.__wrapped_by_forward_after_recv__ = True
    return forward_after_recv_fn


def forward_before_send(forward_fn):
    if hasattr(forward_fn, "__wrapped_by_forward_before_send__"):
        return forward_fn

    def forward_before_send_fn(
        position_ids,
@@ -204,6 +213,7 @@ def forward_before_send(forward_fn):
        pp_send(hidden_states)
        return output

    forward_before_send_fn.__wrapped_by_forward_before_send__ = True
    return forward_before_send_fn
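
The new `__wrapped_by_*__` sentinels make both wrappers idempotent: wrapping an already-wrapped forward returns it unchanged instead of stacking another recv/send around the forward pass. A self-contained sketch of the pattern (the real `pp_send` call is elided):

    def forward_before_send(forward_fn):
        # Guard: if forward_fn already carries the sentinel, it was wrapped
        # before, so return it as-is rather than nesting another wrapper.
        if hasattr(forward_fn, "__wrapped_by_forward_before_send__"):
            return forward_fn

        def forward_before_send_fn(*args, **kwargs):
            output = forward_fn(*args, **kwargs)
            # pp_send(hidden_states) would happen here in the real code.
            return output

        forward_before_send_fn.__wrapped_by_forward_before_send__ = True
        return forward_before_send_fn

    def forward(x):
        return x

    wrapped_once = forward_before_send(forward)
    wrapped_twice = forward_before_send(wrapped_once)
    assert wrapped_once is wrapped_twice  # second wrap is a no-op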
@@ -411,6 +421,8 @@ class DecoderModelForCausalLM(nn.Module,
        for module in self.epilogue:
            skip_forward(module)
        self.model.__pp_init__()

    def __post_init__(self):
        # 1. mixed precision
        quant_config_dict = self.model_config.quant_config_dict
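
Taken together, `__pp_init__` can now be re-run safely: epilogue modules this pipeline stage never executes are skipped (and their weights removed), while boundary layers are wrapped at most once thanks to the sentinels. A hedged sketch of how such wiring might look; the rank flags and `local_layers` partitioning below are illustrative assumptions, not this repo's actual fields:

    def __pp_init__(self):
        # Hypothetical pipeline wiring; is_first_pp_rank/is_last_pp_rank and
        # local_layers are illustrative names only.
        if not self.is_first_pp_rank:
            # First local layer waits for activations from the previous stage.
            first = self.local_layers[0]
            first.forward = forward_after_recv(first.forward)
        if not self.is_last_pp_rank:
            # Last local layer forwards activations to the next stage, and the
            # epilogue (e.g. lm_head) this rank never runs is skipped entirely.
            last = self.local_layers[-1]
            last.forward = forward_before_send(last.forward)
            for module in self.epilogue:
                skip_forward(module)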