[fix] graph-oom & ddp-pos_cis

jingyaogong 2025-10-23 14:22:13 +08:00
parent ce693f8e7f
commit fa7dff8291
9 changed files with 13 additions and 13 deletions
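Editor's note on the first half of the fix ("graph-oom"): in every trainer touched below, `loss` is still a CUDA tensor attached to the autograd graph when it reaches `wandb.log`, so the logger's history keeps that graph (and its GPU memory) alive step after step. Switching to `loss.item()` hands wandb a plain Python float instead. A minimal, self-contained sketch of the difference; the tensor and `accumulation_steps` value are made up for illustration:

import torch

# Stand-in for a training loss that still carries a grad_fn (illustrative only).
loss = (torch.randn(8, 10, requires_grad=True) ** 2).mean()
accumulation_steps = 8

# Before: the dict holds the tensor, and through it the whole autograd graph.
before = {"loss": loss * accumulation_steps}

# After: .item() returns a detached Python float, so nothing from the graph is retained.
after = {"loss": loss.item() * accumulation_steps}

print(type(before["loss"]), type(after["loss"]))  # <class 'torch.Tensor'> <class 'float'>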

View File

@@ -89,7 +89,7 @@ def train_epoch(epoch, wandb):
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
             if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss * args.accumulation_steps,
+                wandb.log({"loss": loss.item() * args.accumulation_steps,
                            "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
@@ -208,7 +208,7 @@ if __name__ == "__main__":
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
     if ddp:
-        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
+        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
     iter_per_epoch = len(train_loader)
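Editor's note on the second half ("ddp-pos_cis"): the rotary-embedding cache was evidently split from a single `pos_cis` buffer into `freqs_cos` / `freqs_sin`, so the name set that tells DDP which buffers to skip has to follow suit; with the stale name, DDP falls back to broadcasting and syncing buffers that every rank can recompute identically on its own. A minimal sketch of the mechanism with a toy module; the module, dimensions, and single-process gloo group are assumptions for illustration (the real trainers run under torchrun):

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

class TinyRoPEModel(nn.Module):
    """Toy model whose RoPE tables are registered as non-trainable buffers."""
    def __init__(self, dim=16, max_len=32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(max_len).float(), freqs)
        self.register_buffer("freqs_cos", angles.cos(), persistent=False)
        self.register_buffer("freqs_sin", angles.sin(), persistent=False)

    def forward(self, x):
        return self.proj(x)

if __name__ == "__main__":
    # Single-process process group so DDP can be constructed on CPU for the demo.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = TinyRoPEModel()
    # List the buffers by their current names, exactly as in the commit.
    model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
    model = DistributedDataParallel(model)

    print(model(torch.randn(2, 4, 16)).shape)  # torch.Size([2, 4, 16])
    dist.destroy_process_group()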

View File

@@ -257,7 +257,7 @@ if __name__ == "__main__":
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
     if ddp:
-        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
+        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
     iter_per_epoch = len(train_loader)

View File

@@ -113,7 +113,7 @@ def train_epoch(epoch, wandb):
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
             if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss * args.accumulation_steps,
+                wandb.log({"loss": loss.item() * args.accumulation_steps,
                            "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
@@ -240,7 +240,7 @@ if __name__ == "__main__":
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
     if ddp:
-        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
+        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
     iter_per_epoch = len(train_loader)

View File

@@ -76,7 +76,7 @@ def train_epoch(epoch, wandb):
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
             if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss * args.accumulation_steps,
+                wandb.log({"loss": loss.item() * args.accumulation_steps,
                            "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})
@@ -194,7 +194,7 @@ if __name__ == "__main__":
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
     if ddp:
-        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
+        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
     iter_per_epoch = len(train_loader)

View File

@@ -309,7 +309,7 @@ if __name__ == "__main__":
     scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
     if ddp:
-        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
+        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
     for epoch in range(args.epochs):

View File

@@ -77,7 +77,7 @@ def train_epoch(epoch, wandb):
                     spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60))
             if (wandb is not None) and (not ddp or dist.get_rank() == 0):
-                wandb.log({"loss": loss * args.accumulation_steps,
+                wandb.log({"loss": loss.item() * args.accumulation_steps,
                            "lr": optimizer.param_groups[-1]['lr'],
                            "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60})

View File

@@ -359,8 +359,8 @@ if __name__ == "__main__":
     # Wrap the models when using distributed training
     if ddp:
-        actor_model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
-        critic_model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
+        actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
+        critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         actor_model = DistributedDataParallel(actor_model, device_ids=[ddp_local_rank])
         critic_model = DistributedDataParallel(critic_model, device_ids=[ddp_local_rank])
         # old_actor_model needs no DDP wrapping: it is only used for computation on the main process and receives no gradient updates

View File

@@ -192,7 +192,7 @@ if __name__ == "__main__":
     optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
     if ddp:
-        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
+        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
     iter_per_epoch = len(train_loader)

View File

@@ -358,7 +358,7 @@ if __name__ == "__main__":
     scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10)
     if ddp:
-        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
+        model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"}
         model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
     value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)