mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[fix] gradient accumulation step alignment
This commit is contained in:
@@ -91,7 +91,7 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
|
||||
@@ -85,7 +85,7 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
|
||||
@@ -36,7 +36,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
|
||||
@@ -148,7 +148,7 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
|
||||
loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
|
||||
loss.backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
if args.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
@@ -37,7 +37,7 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(lora_params, args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
|
||||
@@ -174,7 +174,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
|
||||
loss = (policy_loss + args.vf_coef * value_loss + args.kl_coef * kl_ref + aux_loss) / args.accumulation_steps # scalar
|
||||
loss.backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
clip_grad_norm_(actor_model.parameters(), args.grad_clip)
|
||||
clip_grad_norm_(critic_model.parameters(), args.grad_clip)
|
||||
actor_optimizer.step()
|
||||
|
||||
@@ -36,7 +36,7 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
|
||||
|
||||
scaler.scale(loss).backward()
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
scaler.unscale_(optimizer)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
|
||||
@@ -191,7 +191,7 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
|
||||
response_masks = completion_mask.float() # [B, R]
|
||||
rho = value_tracker.update(rewards, per_token_logps.detach(), response_masks)
|
||||
|
||||
if (step + 1) % args.accumulation_steps == 0:
|
||||
if step % args.accumulation_steps == 0:
|
||||
if args.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
Reference in New Issue
Block a user