mirror of https://github.com/jingyaogong/minimind.git
synced 2026-01-13 19:57:20 +08:00

[update] aux loss

This commit is contained in:
parent c65335b56f
commit 9d898576ac

@@ -284,7 +284,7 @@ class MoEGate(nn.Module):
                 fi = ce * self.n_routed_experts
                 aux_loss = (Pi * fi).sum() * self.alpha
         else:
-            aux_loss = 0
+            aux_loss = scores.new_zeros(1).squeeze()
         return topk_idx, topk_weight, aux_loss

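Note on the change above: returning `scores.new_zeros(1).squeeze()` instead of the Python int `0` keeps `aux_loss` a 0-dim tensor on the same device and dtype as the router scores even when no balancing term is computed, so callers can add it to other losses and call `.item()` on it without special-casing. Below is a minimal, self-contained sketch of the load-balancing term this gate computes; the function name, the `training` flag, and the [N, E]/[N, k] shapes are assumptions for illustration, not code taken from the repo.

    import torch
    import torch.nn.functional as F

    def moe_load_balance_aux_loss(scores, topk_idx, n_routed_experts, alpha, training=True):
        # scores: [N, E] softmax routing probabilities; topk_idx: [N, k] selected expert ids
        if not training or alpha <= 0.0:
            # mirror the new behavior above: a tensor zero, never the int 0
            return scores.new_zeros(1).squeeze()
        mask_ce = F.one_hot(topk_idx.view(-1), num_classes=n_routed_experts)  # [N*k, E]
        ce = mask_ce.float().mean(0)        # fraction of routing slots taken by each expert
        Pi = scores.mean(0)                 # mean routing probability per expert
        fi = ce * n_routed_experts          # matches `fi = ce * self.n_routed_experts`
        return (Pi * fi).sum() * alpha      # matches `aux_loss = (Pi * fi).sum() * self.alpha`

    scores = torch.softmax(torch.randn(8, 4), dim=-1)   # 8 tokens routed over 4 experts
    topk_idx = scores.topk(2, dim=-1).indices
    print(moe_load_balance_aux_loss(scores, topk_idx, n_routed_experts=4, alpha=0.1))
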
@@ -423,12 +423,7 @@ class MiniMindModel(nn.Module):

         hidden_states = self.norm(hidden_states)

-        aux_loss = sum(
-            layer.mlp.aux_loss
-            for layer in self.layers
-            if isinstance(layer.mlp, MOEFeedForward)
-        )
+        aux_loss = sum([l.mlp.aux_loss for l in self.layers if isinstance(l.mlp, MOEFeedForward)], hidden_states.new_zeros(1).squeeze())

         return hidden_states, presents, aux_loss

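Note on the change above: Python's built-in `sum(iterable, start)` starts accumulating from `start`, so passing `hidden_states.new_zeros(1).squeeze()` makes `aux_loss` a 0-dim tensor even for a dense model with no `MOEFeedForward` layers. That keeps `res.aux_loss.item()` and `loss + res.aux_loss` valid in every trainer touched by this commit. A small self-contained check (the per-layer values are made up):

    import torch

    hidden_states = torch.randn(2, 3)
    moe_layer_aux = [torch.tensor(0.02), torch.tensor(0.03)]       # hypothetical per-layer aux losses

    aux_loss = sum(moe_layer_aux, hidden_states.new_zeros(1).squeeze())
    print(aux_loss)                                                # tensor(0.0500)

    dense_aux = sum([], hidden_states.new_zeros(1).squeeze())      # dense model: no MoE layers at all
    print(dense_aux, dense_aux.item())                             # tensor(0.) 0.0 -- .item() still works
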
@@ -53,8 +53,8 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
             loss_mask_sum = loss_mask.sum()
             loss_mask[sp_ids] = 10 # give the thinking-tag tokens 10x weight
             loss_mask = loss_mask.view(Y.size())
-            loss = (loss * loss_mask).sum() / loss_mask_sum
-            loss += res.aux_loss
+            logits_loss = (loss * loss_mask).sum() / loss_mask_sum
+            loss = logits_loss + res.aux_loss
             loss = loss / args.accumulation_steps

         scaler.scale(loss).backward()

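Note on the hunk above: `loss_mask_sum` is captured before the thinking-tag positions are boosted to 10, so the boosted tokens raise the numerator without inflating the denominator, and splitting out `logits_loss` before adding `res.aux_loss` is what feeds the new log fields below. A toy sketch of the weighting; the shapes and the way `sp_ids` is built are assumptions:

    import torch

    per_token_ce = torch.rand(2, 6)                 # stand-in for the per-token cross-entropy, [B, T]
    loss_mask = torch.ones(2, 6)
    loss_mask[:, :2] = 0                            # e.g. prompt/padding positions are excluded
    loss_mask_sum = loss_mask.sum()                 # denominator fixed *before* the boost

    sp_ids = torch.zeros(2, 6, dtype=torch.bool)
    sp_ids[:, -1] = True                            # pretend the last position is a thinking tag
    loss_mask[sp_ids] = 10                          # 10x weight on thinking-tag tokens

    logits_loss = (per_token_ce * loss_mask).sum() / loss_mask_sum
    print(logits_loss)                              # the boosted positions dominate the average
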
@@ -69,12 +69,14 @@ def train_epoch(epoch, loader, iters, tokenizer, lm_config, start_step=0, wandb=
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
+            current_logits_loss = logits_loss.item()
+            current_aux_loss = res.aux_loss.item()
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

-            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

-            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
+            if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
             model.eval()

@@ -71,9 +71,9 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
             ignore_index=0,
             reduction='none'
         )
-        ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
-        if lm_config_student.use_moe:
-            ce_loss += res.aux_loss
+        ce_loss_raw = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
+        if lm_config_student.use_moe: ce_loss = ce_loss_raw + res.aux_loss
+        else: ce_loss = ce_loss_raw

         # 2) Distillation Loss
         if teacher_model is not None:

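Note on the hunk above: keeping `ce_loss_raw` as its own tensor lets the logging hunk below report the plain student cross-entropy and the MoE balancing term separately, while the combined `ce_loss` only includes the aux term for an MoE student. The same keep-the-raw-loss, add-for-backward pattern repeats in every trainer in this commit; a compact illustration with made-up values:

    import torch

    ce_loss_raw = torch.tensor(2.31)     # hypothetical masked CE of the student
    aux_loss = torch.tensor(0.04)        # hypothetical MoE load-balancing loss
    use_moe = True

    ce_loss = ce_loss_raw + aux_loss if use_moe else ce_loss_raw
    print(f"ce: {ce_loss_raw.item():.4f}, aux: {aux_loss.item():.4f}, total: {ce_loss.item():.4f}")
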
@@ -100,18 +100,21 @@ def train_epoch(epoch, loader, iters, teacher_model, lm_config_student, start_st
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
+            current_ce_loss = ce_loss_raw.item()
+            current_aux_loss = res.aux_loss.item() if lm_config_student.use_moe else 0.0
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

-            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} ce:{ce_loss.item():.4f} distill:{distill_loss.item():.4f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, ce: {current_ce_loss:.4f}, aux_loss: {current_aux_loss:.4f}, distill: {distill_loss.item():.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

             if wandb:
                 wandb.log({
                     "loss": current_loss,
-                    "ce_loss": ce_loss.item(),
+                    "ce_loss": current_ce_loss,
+                    "aux_loss": current_aux_loss,
                     "distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
-                    "lr": current_lr,
-                    "epoch_Time": eta_min
+                    "learning_rate": current_lr,
+                    "epoch_time": eta_min
                 })

         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():

@@ -79,7 +79,8 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
             logits = outputs.logits
             policy_log_probs = logits_to_log_probs(logits, y)

-            loss = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
+            dpo_loss_val = dpo_loss(ref_log_probs, policy_log_probs, mask, beta=beta)
+            loss = dpo_loss_val + outputs.aux_loss
             loss = loss / args.accumulation_steps

         scaler.scale(loss).backward()

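Note on the hunk above: naming the preference term `dpo_loss_val` before adding `outputs.aux_loss` follows the same split-then-combine pattern and feeds the new log fields below. For reference, a sketch of the standard DPO objective that a `dpo_loss` helper of this shape typically computes; the chosen/rejected batch layout and the per-token reduction here are assumptions, not the repo's exact implementation:

    import torch
    import torch.nn.functional as F

    def dpo_loss_sketch(ref_log_probs, policy_log_probs, mask, beta=0.1):
        # [B, T] per-token log-probs; assume the first half of the batch holds the
        # chosen responses and the second half the rejected ones.
        seq_logp = (policy_log_probs * mask).sum(dim=-1)
        ref_seq_logp = (ref_log_probs * mask).sum(dim=-1)
        half = seq_logp.size(0) // 2
        pi_logratio = seq_logp[:half] - seq_logp[half:]
        ref_logratio = ref_seq_logp[:half] - ref_seq_logp[half:]
        return -F.logsigmoid(beta * (pi_logratio - ref_logratio)).mean()

    policy_lp = -torch.rand(4, 8)        # toy [B=4, T=8] per-token log-probs (negative values)
    ref_lp = -torch.rand(4, 8)
    print(dpo_loss_sketch(ref_lp, policy_lp, torch.ones(4, 8)))
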
@@ -94,12 +95,14 @@ def train_epoch(epoch, loader, iters, ref_model, lm_config, start_step=0, wandb=
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
+            current_dpo_loss = dpo_loss_val.item()
+            current_aux_loss = outputs.aux_loss.item()
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

-            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, dpo_loss: {current_dpo_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

-            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
+            if wandb: wandb.log({"loss": current_loss, "dpo_loss": current_dpo_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
             model.eval()

@@ -38,8 +38,8 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
                 Y.view(-1)
             ).view(Y.size())

-            loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss += res.aux_loss
+            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
+            loss = logits_loss + res.aux_loss
             loss = loss / args.accumulation_steps

         scaler.scale(loss).backward()

@@ -56,12 +56,14 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
+            current_logits_loss = logits_loss.item()
+            current_aux_loss = res.aux_loss.item()
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

-            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

-            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
+            if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
             model.eval()

@@ -119,7 +119,11 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
                 per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
             return torch.stack(per_token_logps)

-        per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B*num_gen, R]
+        with autocast_ctx:
+            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B*num_gen, R]
+            res = model(outputs) if lm_config.use_moe else None
+            aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)

         with torch.no_grad():
             ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B*num_gen, R]

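Note on the hunk above: wrapping both the log-prob computation and the extra `model(outputs)` forward in `autocast_ctx` keeps the new MoE pass in the same mixed-precision regime as the rest of the step, and when `use_moe` is off `aux_loss` falls back to a zero tensor on the right device. For context, a hedged reconstruction of the `get_per_token_logps` helper whose last two lines appear above; the signature is simplified to take precomputed logits, so the exact slicing may differ from the repo:

    import torch

    def get_per_token_logps_sketch(logits, input_ids, logits_to_keep):
        # logits: [B, T, V] over the full prompt+completion; input_ids: [B, T]
        logits = logits[:, -logits_to_keep - 1:-1, :]   # positions that predict the completion tokens
        ids = input_ids[:, -logits_to_keep:]            # the completion tokens themselves
        per_token_logps = []
        for logits_row, ids_row in zip(logits, ids):
            per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
        return torch.stack(per_token_logps)             # [B, logits_to_keep]

    logits = torch.randn(2, 10, 16)                     # toy batch: 2 sequences, vocab of 16
    ids = torch.randint(0, 16, (2, 10))
    print(get_per_token_logps_sketch(logits, ids, logits_to_keep=4).shape)  # torch.Size([2, 4])
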
@@ -140,7 +144,8 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
         kl_div = ref_per_token_logps - per_token_logps
         per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B*num_gen, R]
         per_token_loss = -(torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) - args.beta * per_token_kl) # [B*num_gen, R]
-        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps # scalar
+        policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+        loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
         loss.backward()

         if (step + 1) % args.accumulation_steps == 0:

@@ -151,18 +156,20 @@ def grpo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_token
             optimizer.zero_grad()

         if step % args.log_interval == 0 or step == iters:
-            policy_loss_val = loss.item()
+            policy_loss_val = loss.item() * args.accumulation_steps
+            current_aux_loss = aux_loss.item()
             avg_reward_val = rewards.mean().item()
             avg_len_val = completion_mask.sum(dim=1).float().mean().item()
             current_lr = optimizer.param_groups[0]['lr']

-            Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
-                   f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
-                   f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
+                   f'Actor Loss: {policy_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, Reward: {avg_reward_val:.4f}, '
+                   f'Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')

             if wandb and is_main_process():
                 wandb.log({
                     "policy_loss": policy_loss_val,
+                    "aux_loss": current_aux_loss,
                     "reward": avg_reward_val,
                     "avg_response_len": avg_len_val,
                     "advantages_mean": advantages.mean().item(),

@@ -39,8 +39,8 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
                 Y.view(-1)
             ).view(Y.size())

-            loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss += res.aux_loss
+            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
+            loss = logits_loss + res.aux_loss
             loss = loss / args.accumulation_steps

         scaler.scale(loss).backward()

@@ -57,12 +57,14 @@ def train_epoch(epoch, loader, iters, lora_params, start_step=0, wandb=None):
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
+            current_logits_loss = logits_loss.item()
+            current_aux_loss = res.aux_loss.item()
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

-            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

-            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
+            if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
             model.eval()

@@ -143,7 +143,11 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
         values = values_seq[torch.arange(values_seq.size(0), device=values_seq.device), last_indices] # [B]
         advantages = rewards - values.detach() # [B]

-        logits = actor_model(input_ids=gen_out, attention_mask=full_mask).logits # [B, P+R, V]
+        with autocast_ctx:
+            res = actor_model(input_ids=gen_out, attention_mask=full_mask)
+            logits = res.logits # [B, P+R, V]
+            aux_loss = res.aux_loss if lm_config.use_moe else torch.tensor(0.0, device=args.device)

         labels = gen_out[:, 1:].clone() # [B, P+R-1]
         logp_tokens = F.log_softmax(logits[:, :-1], dim=-1).gather(2, labels.unsqueeze(-1)).squeeze(-1) # [B, P+R-1]
         seq_len = gen_out.size(1) - 1

@@ -167,7 +171,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
         surr2 = torch.clamp(ratio, 1.0 - args.clip_epsilon, 1.0 + args.clip_epsilon) * advantages # [B]
         policy_loss = -torch.min(surr1, surr2).mean() # scalar
         value_loss = F.mse_loss(values, rewards) # scalar
-        loss = policy_loss + args.vf_coef * value_loss + args.kl_coef * kl_ref # scalar
+        loss = (policy_loss + args.vf_coef * value_loss + args.kl_coef * kl_ref + aux_loss) / args.accumulation_steps # scalar
         loss.backward()

         if (step + 1) % args.accumulation_steps == 0:

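Note on the hunk above: the old PPO line never divided by `args.accumulation_steps`, so gradients from an accumulation window summed rather than averaged; the new line fixes that while also folding `aux_loss` into the objective. A minimal demonstration of why the division matters (the quadratic toy loss is only for illustration):

    import torch

    accumulation_steps = 4
    w = torch.zeros(1, requires_grad=True)
    targets = torch.tensor([1.0, 2.0, 3.0, 4.0])

    for t in targets:                                   # backward() accumulates (adds) gradients
        micro_loss = (w - t).pow(2).sum()
        (micro_loss / accumulation_steps).backward()

    print(w.grad)   # tensor([-5.]) -- the average gradient, as one batch of 4 would give
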
@@ -190,6 +194,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche

             actor_loss_val = policy_loss.item()
             critic_loss_val = value_loss.item()
+            current_aux_loss = aux_loss.item()
             reward_val = rewards.mean().item()
             kl_val = kl.item()
             kl_ref_val = kl_ref.item()

@@ -201,6 +206,7 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
                 wandb.log({
                     "actor_loss": actor_loss_val,
                     "critic_loss": critic_loss_val,
+                    "aux_loss": current_aux_loss,
                     "reward": reward_val,
                     "kl": kl_val,
                     "kl_ref": kl_ref_val,

@@ -208,10 +214,10 @@ def ppo_train_epoch(epoch, loader, iters, old_actor_model, ref_model, actor_sche
                     "actor_lr": actor_lr,
                 })

-            Logger(f"Epoch: {epoch+1}, Step: {step}/{iters}, "
-                   f"Actor Loss: {actor_loss_val:.6f}, Critic Loss: {critic_loss_val:.6f}, "
-                   f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, "
-                   f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}")
+            Logger(f"Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), "
+                   f"Actor Loss: {actor_loss_val:.4f}, Critic Loss: {critic_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, "
+                   f"Reward: {reward_val:.4f}, KL: {kl_val:.4f}, KL_ref: {kl_ref_val:.4f}, "
+                   f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.8f}, Critic LR: {critic_lr:.8f}")

         if (step + 1) % args.update_old_actor_freq == 0:
             state_dict = actor_model.module.state_dict() if isinstance(actor_model, DistributedDataParallel) else actor_model.state_dict()

@@ -38,8 +38,8 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
                 Y.view(-1)
             ).view(Y.size())

-            loss = (loss * loss_mask).sum() / loss_mask.sum()
-            loss += res.aux_loss
+            logits_loss = (loss * loss_mask).sum() / loss_mask.sum()
+            loss = logits_loss + res.aux_loss
             loss = loss / args.accumulation_steps

         scaler.scale(loss).backward()

@@ -56,12 +56,14 @@ def train_epoch(epoch, loader, iters, start_step=0, wandb=None):
         if step % args.log_interval == 0 or step == iters - 1:
             spend_time = time.time() - start_time
             current_loss = loss.item() * args.accumulation_steps
+            current_logits_loss = logits_loss.item()
+            current_aux_loss = res.aux_loss.item()
             current_lr = optimizer.param_groups[-1]['lr']
             eta_min = spend_time / (step + 1) * iters // 60 - spend_time // 60

-            Logger(f'Epoch:[{epoch+1}/{args.epochs}]({step}/{iters}) loss:{current_loss:.6f} lr:{current_lr:.12f} epoch_Time:{eta_min}min:')
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), loss: {current_loss:.4f}, logits_loss: {current_logits_loss:.4f}, aux_loss: {current_aux_loss:.4f}, learning_rate: {current_lr:.8f}, epoch_time: {eta_min:.3f}min')

-            if wandb: wandb.log({"loss": current_loss, "lr": current_lr, "epoch_Time": eta_min})
+            if wandb: wandb.log({"loss": current_loss, "logits_loss": current_logits_loss, "aux_loss": current_aux_loss, "learning_rate": current_lr, "epoch_time": eta_min})

         if (step % args.save_interval == 0 or step == iters - 1) and is_main_process():
             model.eval()

@@ -155,7 +155,11 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
                 per_token_logps.append(torch.gather(logits_row.log_softmax(dim=-1), 1, ids_row.unsqueeze(1)).squeeze(1))
             return torch.stack(per_token_logps)

-        per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B, R]
+        with autocast_ctx:
+            per_token_logps = get_per_token_logps(model, outputs, completion_ids.size(1)) # [B, R]
+            res = model(outputs) if lm_config.use_moe else None
+            aux_loss = res.aux_loss if res is not None else torch.tensor(0.0, device=args.device)

         with torch.no_grad():
             ref_per_token_logps = get_per_token_logps(ref_model, outputs, completion_ids.size(1)) # [B, R]

@@ -180,7 +184,8 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
         kl_div = ref_per_token_logps - per_token_logps # [B, R]
         per_token_kl = torch.exp(kl_div) - kl_div - 1 # [B, R]
         per_token_loss = -per_token_logps * advantages.unsqueeze(1) + args.beta * per_token_kl # [B, R]
-        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() / args.accumulation_steps # scalar
+        policy_loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+        loss = (policy_loss + aux_loss) / args.accumulation_steps # scalar
         loss.backward()

         response_masks = completion_mask.float() # [B, R]

@@ -194,21 +199,23 @@ def spo_train_epoch(epoch, loader, iters, ref_model, reward_model, reward_tokeni
             optimizer.zero_grad()

         if step % args.log_interval == 0 or step == iters:
-            policy_loss_val = loss.item()
+            policy_loss_val = loss.item() * args.accumulation_steps
+            current_aux_loss = aux_loss.item()
             avg_reward_val = rewards.mean().item()
             avg_len_val = completion_mask.sum(dim=1).float().mean().item()
             kl_val = ((per_token_kl * completion_mask).sum() / (completion_mask.sum() + 1e-8)).item()
             avg_baseline_val = baselines.mean().item()
             current_lr = optimizer.param_groups[0]['lr']

-            Logger(f'Epoch: {epoch+1}, Step: {step}/{iters}, '
-                   f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
+            Logger(f'Epoch:[{epoch + 1}/{args.epochs}]({step}/{iters}), '
+                   f'Actor Loss: {policy_loss_val:.4f}, Aux Loss: {current_aux_loss:.4f}, Reward: {avg_reward_val:.4f}, '
                    f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, '
-                   f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
+                   f'Avg Response Len: {avg_len_val:.2f}, Learning Rate: {current_lr:.8f}')

             if wandb and is_main_process():
                 wandb.log({
                     "policy_loss": policy_loss_val,
+                    "aux_loss": current_aux_loss,
                     "reward": avg_reward_val,
                     "kl": kl_val,
                     "rho": float(rho),