Merge pull request #698 from readlnh/master

[fix] 修复训练脚本中 1-indexed step 与 0-indexed 逻辑混用的问题
This commit is contained in:
jingyaogong
2026-03-24 13:41:20 +08:00
committed by GitHub
9 changed files with 30 additions and 30 deletions
+4 -4
View File
@@ -91,20 +91,20 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_ce_loss = ce_loss_raw.item()
current_aux_loss = res.aux_loss.item() if lm_config_student.use_moe else 0.0
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
eta_min = spend_time / step * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, ce: {current_ce_loss:.4f}, aux_loss: {current_aux_loss:.4f}, distill: {distill_loss.item():.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
@@ -118,7 +118,7 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
"epoch_time": eta_min
})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config_student.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config_student.hidden_size}{moe_suffix}.pth'
+4 -4
View File
@@ -85,26 +85,26 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_dpo_loss = dpo_loss_val.item()
current_aux_loss = outputs.aux_loss.item()
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
eta_min = spend_time / step * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, dpo_loss: {current_dpo_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')
if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+4 -4
View File
@@ -36,7 +36,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@@ -45,17 +45,17 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
eta_min = spend_time / step * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+2 -2
View File
@@ -148,7 +148,7 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
loss.backward()
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
@@ -176,7 +176,7 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
"learning_rate": current_lr
})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+4 -4
View File
@@ -37,24 +37,24 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
eta_min = spend_time / step * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
lora_save_path = f'{args.save_dir}/{args.lora_name}_{lm_config.hidden_size}.pth'
# LoRA只保存LoRA权重
+2 -2
View File
@@ -174,7 +174,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
loss = (policy_loss + args.vf_coef * value_loss + args.kl_coef * kl_ref + aux_loss) / args.accumulation_steps # scalar
loss.backward()
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
actor_optimizer.step()
@@ -226,7 +226,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
old_actor_model.load_state_dict({k: v.detach().cpu() for k, v in state_dict.items()})
old_actor_model.to(args.device)
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
actor_model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+4 -4
View File
@@ -36,7 +36,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@@ -45,17 +45,17 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = current_loss - current_aux_loss
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
eta_min = spend_time / step * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+4 -4
View File
@@ -56,24 +56,24 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if step % args.log_interval == 0 or step == iters - 1:
if step % args.log_interval == 0 or step == iters:
spend_time = time.time() - start_time
current_loss = loss.item() * args.accumulation_steps
current_aux_loss = res.aux_loss.item() if res.aux_loss is not None else 0.0
current_logits_loss = logits_loss.item()
current_lr = optimizer.param_groups[-1]['lr']
eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60
eta_min = spend_time / step * iters // 60 - spend_time // 60
Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, lr: {current_lr:.8f}, epoch_time: {eta_min:.1f}min')
if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'
+2 -2
View File
@@ -191,7 +191,7 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
response_masks = completion_mask.float() # [B, R]
rho = value_tracker.update(rewards, per_token_logps.detach(), response_masks)
if (step + 1) % args.accumulation_steps == 0:
if step % args.accumulation_steps == 0:
if args.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()
@@ -224,7 +224,7 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
"learning_rate": current_lr
})
if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
if (step % args.save_interval == 0 or step == iters) and is_main_process():
model.eval()
moe_suffix = '_moe' if lm_config.use_moe else ''
ckp = f'{args.save_dir}/{args.save_weight}_{lm_config.hidden_size}{moe_suffix}.pth'