diff --git a/trainer/train_dpo.py b/trainer/train_dpo.py index bc1fe0d..c7701d0 100644 --- a/trainer/train_dpo.py +++ b/trainer/train_dpo.py @@ -34,17 +34,17 @@ def logits_to_probs(logits, labels): # logits shape: (batch_size, seq_len, vocab_size) # labels shape: (batch_size, seq_len) # probs shape: (batch_size, seq_len) - log_probs = F.log_softmax(logits, hidden_size=2) - probs = torch.gather(log_probs, hidden_size=2, index=labels.unsqueeze(2)).squeeze(-1) + log_probs = F.log_softmax(logits, dim=2) + probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1) return probs def dpo_loss(ref_probs, probs, mask, beta): # ref_probs 和 probs 都是 shape: (batch_size, seq_len) # https://github.com/jingyaogong/minimind/issues/298 - seq_lengths = mask.sum(hidden_size=1, keephidden_size=True) # (batch_size, 1) - ref_probs = (ref_probs * mask).sum(hidden_size=1) / seq_lengths.squeeze() - probs = (probs * mask).sum(hidden_size=1) / seq_lengths.squeeze() + seq_lengths = mask.sum(dim=1, keepdim=True) # (batch_size, 1) + ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze() + probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze() # 将 chosen 和 rejected 数据分开 batch_size = ref_probs.shape[0] @@ -69,9 +69,9 @@ def train_epoch(epoch, wandb): y_rejected = batch['y_rejected'].to(args.device) mask_chosen = batch['mask_chosen'].to(args.device) mask_rejected = batch['mask_rejected'].to(args.device) - x = torch.cat([x_chosen, x_rejected], hidden_size=0) - y = torch.cat([y_chosen, y_rejected], hidden_size=0) - mask = torch.cat([mask_chosen, mask_rejected], hidden_size=0) + x = torch.cat([x_chosen, x_rejected], dim=0) + y = torch.cat([y_chosen, y_rejected], dim=0) + mask = torch.cat([mask_chosen, mask_rejected], dim=0) lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate) for param_group in optimizer.param_groups: @@ -166,7 +166,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="MiniMind RLHF") parser.add_argument("--out_dir", type=str, default="../out") parser.add_argument("--epochs", type=int, default=2) - parser.add_argument("--batch_size", type=int, default=8) + parser.add_argument("--batch_size", type=int, default=4) # sft阶段学习率为 「5e-6」->「5e-7」长度512,建议离线正负样本「概率」偏好对齐阶段lr <=「1e-8」长度3000,否则很容易遗忘训坏 parser.add_argument("--learning_rate", type=float, default=1e-8) parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu") diff --git a/trainer/train_full_sft.py b/trainer/train_full_sft.py index 0670a11..e2a9a2f 100644 --- a/trainer/train_full_sft.py +++ b/trainer/train_full_sft.py @@ -136,11 +136,11 @@ if __name__ == "__main__": parser.add_argument("--log_interval", type=int, default=100) parser.add_argument("--save_interval", type=int, default=100) parser.add_argument('--local_rank', type=int, default=-1) - parser.add_argument('--hidden_size', default=768, type=int) - parser.add_argument('--num_hidden_layers', default=16, type=int) + parser.add_argument('--hidden_size', default=512, type=int) + parser.add_argument('--num_hidden_layers', default=8, type=int) parser.add_argument('--max_seq_len', default=512, type=int) parser.add_argument('--use_moe', default=False, type=bool) - parser.add_argument("--data_path", type=str, default="../dataset/sft_512.jsonl") + parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl") args = parser.parse_args()