mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-23 15:58:15 +08:00
[fix] loss-issues-430
This commit is contained in:
parent
4014e62cdf
commit
805744e60a
@ -79,7 +79,7 @@ def train_epoch(epoch, wandb):
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
- 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
+ 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
|
||||
@ -113,7 +113,7 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
- 'Epoch:[{}/{}]({}/{}) loss:{:.4f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
+ 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch,
|
||||
args.epochs - 1,
|
||||
step,
|
||||
|
||||
@ -103,7 +103,7 @@ def train_epoch(epoch, wandb):
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
- 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
+ 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
|
||||
@ -66,7 +66,7 @@ def train_epoch(epoch, wandb):
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
- 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
+ 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
|
||||
@ -169,7 +169,7 @@ def grpo_train_epoch(epoch, wandb):
|
||||
|
||||
Logger(
|
||||
f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
|
||||
- f'Actor Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
|
||||
+ f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
|
||||
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
|
||||
|
||||
if wandb and (not ddp or dist.get_rank() == 0):
|
||||
|
||||
@ -67,7 +67,7 @@ def train_epoch(epoch, wandb):
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
- 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
+ 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
|
||||
@ -193,8 +193,8 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
|
||||
})
|
||||
|
||||
Logger(f"Epoch: {epoch}, Step: {step + 1}/{len(train_loader)}, "
|
||||
- f"Actor Loss: {actor_loss_val:.4f}, Critic Loss: {critic_loss_val:.4f}, "
|
||||
- f"Reward: {reward_val:.4f}, KL: {kl_val:.4f}, KL_ref: {kl_ref_val:.4f}, "
|
||||
+ f"Actor Loss: {actor_loss_val:.6f}, Critic Loss: {critic_loss_val:.6f}, "
|
||||
+ f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, "
|
||||
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}")
|
||||
|
||||
if (step + 1) % args.update_old_actor_freq == 0:
|
||||
|
||||
@ -66,7 +66,7 @@ def train_epoch(epoch, wandb):
|
||||
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
|
||||
spend_time = time.time() - start_time
|
||||
Logger(
|
||||
- 'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
+ 'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
|
||||
epoch + 1,
|
||||
args.epochs,
|
||||
step,
|
||||
|
||||
@ -216,7 +216,7 @@ def spo_train_epoch(epoch, wandb, value_tracker):
|
||||
|
||||
Logger(
|
||||
f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
|
||||
- f'Actor Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
|
||||
+ f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
|
||||
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
|
||||
|
||||
if wandb and (not ddp or dist.get_rank() == 0):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user