diff --git a/trainer/train_distill_reason.py b/trainer/train_distill_reason.py index 7bd6bce..ab7866c 100644 --- a/trainer/train_distill_reason.py +++ b/trainer/train_distill_reason.py @@ -89,7 +89,7 @@ def train_epoch(epoch, wandb): spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) if (wandb is not None) and (not ddp or dist.get_rank() == 0): - wandb.log({"loss": loss * args.accumulation_steps, + wandb.log({"loss": loss.item() * args.accumulation_steps, "lr": optimizer.param_groups[-1]['lr'], "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) @@ -208,7 +208,7 @@ if __name__ == "__main__": optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) if ddp: - model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) iter_per_epoch = len(train_loader) diff --git a/trainer/train_distillation.py b/trainer/train_distillation.py index cccfee1..f4e2d41 100644 --- a/trainer/train_distillation.py +++ b/trainer/train_distillation.py @@ -257,7 +257,7 @@ if __name__ == "__main__": optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) if ddp: - model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) iter_per_epoch = len(train_loader) diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index db17164..c4acf36 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -113,7 +113,7 @@ def train_epoch(epoch, wandb): spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) if (wandb is not None) and (not ddp or dist.get_rank() == 0): - wandb.log({"loss": loss * args.accumulation_steps, + wandb.log({"loss": loss.item() * args.accumulation_steps, "lr": optimizer.param_groups[-1]['lr'], "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) @@ -240,7 +240,7 @@ if __name__ == "__main__": optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) if ddp: - model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) iter_per_epoch = len(train_loader) diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index 49a3ff0..58f87aa 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -76,7 +76,7 @@ def train_epoch(epoch, wandb): spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) if (wandb is not None) and (not ddp or dist.get_rank() == 0): - wandb.log({"loss": loss * args.accumulation_steps, + wandb.log({"loss": loss.item() * args.accumulation_steps, "lr": optimizer.param_groups[-1]['lr'], "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) @@ -194,7 +194,7 @@ if __name__ == "__main__": optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) if ddp: - model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) iter_per_epoch = len(train_loader) diff --git a/trainer/train_grpo.py b/trainer/train_grpo.py index a3e943c..0402f41 100755 --- a/trainer/train_grpo.py +++ b/trainer/train_grpo.py @@ -309,7 +309,7 @@ if __name__ == "__main__": scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10) if ddp: - model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) for epoch in range(args.epochs): diff --git a/trainer/train_lora.py b/trainer/train_lora.py index 7d74b9f..5da1607 100644 --- a/trainer/train_lora.py +++ b/trainer/train_lora.py @@ -77,7 +77,7 @@ def train_epoch(epoch, wandb): spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60)) if (wandb is not None) and (not ddp or dist.get_rank() == 0): - wandb.log({"loss": loss * args.accumulation_steps, + wandb.log({"loss": loss.item() * args.accumulation_steps, "lr": optimizer.param_groups[-1]['lr'], "epoch_Time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60}) diff --git a/trainer/train_ppo.py b/trainer/train_ppo.py index dd67775..1132da9 100644 --- a/trainer/train_ppo.py +++ b/trainer/train_ppo.py @@ -359,8 +359,8 @@ if __name__ == "__main__": # 如果使用分布式训练,包装模型 if ddp: - actor_model._ddp_params_and_buffers_to_ignore = {"pos_cis"} - critic_model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + actor_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} + critic_model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} actor_model = DistributedDataParallel(actor_model, device_ids=[ddp_local_rank]) critic_model = DistributedDataParallel(critic_model, device_ids=[ddp_local_rank]) # old_actor_model 不需要DDP包装,因为它只在主进程上用于计算,并且不进行梯度更新 diff --git a/trainer/train_pretrain.py b/trainer/train_pretrain.py index 6eed5a7..4166db5 100644 --- a/trainer/train_pretrain.py +++ b/trainer/train_pretrain.py @@ -192,7 +192,7 @@ if __name__ == "__main__": optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) if ddp: - model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) iter_per_epoch = len(train_loader) diff --git a/trainer/train_spo.py b/trainer/train_spo.py index 7e44d3a..8db67b9 100755 --- a/trainer/train_spo.py +++ b/trainer/train_spo.py @@ -358,7 +358,7 @@ if __name__ == "__main__": scheduler = CosineAnnealingLR(optimizer, T_max=total_optimizer_steps, eta_min=args.learning_rate / 10) if ddp: - model._ddp_params_and_buffers_to_ignore = {"pos_cis"} + model._ddp_params_and_buffers_to_ignore = {"freqs_cos", "freqs_sin"} model = DistributedDataParallel(model, device_ids=[ddp_local_rank]) value_tracker = AutoAdaptiveValueTracker(rho_mode='kl', rho_const=0.9, D_half=0.06, clip_lower=0.5, clip_upper=0.96)