[fix] loss-issues-430

This commit is contained in:
jingyaogong 2025-10-23 19:08:42 +08:00
parent 4014e62cdf
commit 805744e60a
9 changed files with 10 additions and 10 deletions

View File

@@ -79,7 +79,7 @@ def train_epoch(epoch, wandb):
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,

View File

@@ -113,7 +113,7 @@ def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.4f} lr:{:.12f} epoch_Time:{}min:'.format(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch,
args.epochs - 1,
step,

View File

@@ -103,7 +103,7 @@ def train_epoch(epoch, wandb):
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,

View File

@@ -66,7 +66,7 @@ def train_epoch(epoch, wandb):
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,

View File

@@ -169,7 +169,7 @@ def grpo_train_epoch(epoch, wandb):
Logger(
f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
f'Actor Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
f'Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
if wandb and (not ddp or dist.get_rank() == 0):

View File

@@ -67,7 +67,7 @@ def train_epoch(epoch, wandb):
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,

View File

@@ -193,8 +193,8 @@ def ppo_train_epoch(epoch: int, wandb_run, old_actor_model, ref_model, actor_sch
})
Logger(f"Epoch: {epoch}, Step: {step + 1}/{len(train_loader)}, "
f"Actor Loss: {actor_loss_val:.4f}, Critic Loss: {critic_loss_val:.4f}, "
f"Reward: {reward_val:.4f}, KL: {kl_val:.4f}, KL_ref: {kl_ref_val:.4f}, "
f"Actor Loss: {actor_loss_val:.6f}, Critic Loss: {critic_loss_val:.6f}, "
f"Reward: {reward_val:.6f}, KL: {kl_val:.6f}, KL_ref: {kl_ref_val:.6f}, "
f"Avg Response Len: {avg_len_val:.2f}, Actor LR: {actor_lr:.2e}, Critic LR: {critic_lr:.2e}")
if (step + 1) % args.update_old_actor_freq == 0:

View File

@@ -66,7 +66,7 @@ def train_epoch(epoch, wandb):
if step % args.log_interval == 0 or step == iter_per_epoch - 1:
spend_time = time.time() - start_time
Logger(
'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.12f} epoch_Time:{}min:'.format(
'Epoch:[{}/{}]({}/{}) loss:{:.6f} lr:{:.12f} epoch_Time:{}min:'.format(
epoch + 1,
args.epochs,
step,

View File

@@ -216,7 +216,7 @@ def spo_train_epoch(epoch, wandb, value_tracker):
Logger(
f'Epoch: {epoch}, Step: {step + 1}/{iter_per_epoch}, '
f'Actor Loss: {policy_loss_val:.4f}, Reward: {avg_reward_val:.4f}, '
f'Actor Loss: {policy_loss_val:.6f}, Reward: {avg_reward_val:.6f}, '
f'Baseline: {avg_baseline_val:.4f}, KL: {kl_val:.4f}, Rho: {rho:.4f}, Avg Response Len: {avg_len_val:.2f}, LR: {current_lr:.2e}')
if wandb and (not ddp or dist.get_rank() == 0):