This commit is contained in:
jingyaogong 2025-04-29 20:45:39 +08:00
parent 5ffde04b7c
commit caae54a89e
2 changed files with 12 additions and 12 deletions

View File

@@ -34,17 +34,17 @@ def logits_to_probs(logits, labels):
# logits shape: (batch_size, seq_len, vocab_size)
# labels shape: (batch_size, seq_len)
# probs shape: (batch_size, seq_len)
log_probs = F.log_softmax(logits, hidden_size=2)
probs = torch.gather(log_probs, hidden_size=2, index=labels.unsqueeze(2)).squeeze(-1)
log_probs = F.log_softmax(logits, dim=2)
probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
return probs
def dpo_loss(ref_probs, probs, mask, beta):
# ref_probs 和 probs 都是 shape: (batch_size, seq_len)
# https://github.com/jingyaogong/minimind/issues/298
seq_lengths = mask.sum(hidden_size=1, keephidden_size=True) # (batch_size, 1)
ref_probs = (ref_probs * mask).sum(hidden_size=1) / seq_lengths.squeeze()
probs = (probs * mask).sum(hidden_size=1) / seq_lengths.squeeze()
seq_lengths = mask.sum(dim=1, keepdim=True) # (batch_size, 1)
ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
# 将 chosen 和 rejected 数据分开
batch_size = ref_probs.shape[0]
@@ -69,9 +69,9 @@ def train_epoch(epoch, wandb):
y_rejected = batch['y_rejected'].to(args.device)
mask_chosen = batch['mask_chosen'].to(args.device)
mask_rejected = batch['mask_rejected'].to(args.device)
x = torch.cat([x_chosen, x_rejected], hidden_size=0)
y = torch.cat([y_chosen, y_rejected], hidden_size=0)
mask = torch.cat([mask_chosen, mask_rejected], hidden_size=0)
x = torch.cat([x_chosen, x_rejected], dim=0)
y = torch.cat([y_chosen, y_rejected], dim=0)
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
for param_group in optimizer.param_groups:
@@ -166,7 +166,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="MiniMind RLHF")
parser.add_argument("--out_dir", type=str, default="../out")
parser.add_argument("--epochs", type=int, default=2)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--batch_size", type=int, default=4)
# sft阶段学习率为 「5e-6」->「5e-7」长度512建议离线正负样本「概率」偏好对齐阶段lr <=「1e-8」长度3000否则很容易遗忘训坏
parser.add_argument("--learning_rate", type=float, default=1e-8)
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")

View File

@@ -136,11 +136,11 @@ if __name__ == "__main__":
parser.add_argument("--log_interval", type=int, default=100)
parser.add_argument("--save_interval", type=int, default=100)
parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--hidden_size', default=768, type=int)
parser.add_argument('--num_hidden_layers', default=16, type=int)
parser.add_argument('--hidden_size', default=512, type=int)
parser.add_argument('--num_hidden_layers', default=8, type=int)
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--use_moe', default=False, type=bool)
parser.add_argument("--data_path", type=str, default="../dataset/sft_512.jsonl")
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl")
args = parser.parse_args()