mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-05-01 11:48:14 +08:00
fix bugs
This commit is contained in:
parent
5ffde04b7c
commit
caae54a89e
@ -34,17 +34,17 @@ def logits_to_probs(logits, labels):
|
||||
# logits shape: (batch_size, seq_len, vocab_size)
|
||||
# labels shape: (batch_size, seq_len)
|
||||
# probs shape: (batch_size, seq_len)
|
||||
log_probs = F.log_softmax(logits, hidden_size=2)
|
||||
probs = torch.gather(log_probs, hidden_size=2, index=labels.unsqueeze(2)).squeeze(-1)
|
||||
log_probs = F.log_softmax(logits, dim=2)
|
||||
probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(-1)
|
||||
return probs
|
||||
|
||||
|
||||
def dpo_loss(ref_probs, probs, mask, beta):
|
||||
# ref_probs 和 probs 都是 shape: (batch_size, seq_len)
|
||||
# https://github.com/jingyaogong/minimind/issues/298
|
||||
seq_lengths = mask.sum(hidden_size=1, keephidden_size=True) # (batch_size, 1)
|
||||
ref_probs = (ref_probs * mask).sum(hidden_size=1) / seq_lengths.squeeze()
|
||||
probs = (probs * mask).sum(hidden_size=1) / seq_lengths.squeeze()
|
||||
seq_lengths = mask.sum(dim=1, keepdim=True) # (batch_size, 1)
|
||||
ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
|
||||
|
||||
# 将 chosen 和 rejected 数据分开
|
||||
batch_size = ref_probs.shape[0]
|
||||
@ -69,9 +69,9 @@ def train_epoch(epoch, wandb):
|
||||
y_rejected = batch['y_rejected'].to(args.device)
|
||||
mask_chosen = batch['mask_chosen'].to(args.device)
|
||||
mask_rejected = batch['mask_rejected'].to(args.device)
|
||||
x = torch.cat([x_chosen, x_rejected], hidden_size=0)
|
||||
y = torch.cat([y_chosen, y_rejected], hidden_size=0)
|
||||
mask = torch.cat([mask_chosen, mask_rejected], hidden_size=0)
|
||||
x = torch.cat([x_chosen, x_rejected], dim=0)
|
||||
y = torch.cat([y_chosen, y_rejected], dim=0)
|
||||
mask = torch.cat([mask_chosen, mask_rejected], dim=0)
|
||||
|
||||
lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
|
||||
for param_group in optimizer.param_groups:
|
||||
@ -166,7 +166,7 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="MiniMind RLHF")
|
||||
parser.add_argument("--out_dir", type=str, default="../out")
|
||||
parser.add_argument("--epochs", type=int, default=2)
|
||||
parser.add_argument("--batch_size", type=int, default=8)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
# sft阶段学习率为 「5e-6」->「5e-7」长度512,建议离线正负样本「概率」偏好对齐阶段lr <=「1e-8」长度3000,否则很容易遗忘训坏
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-8)
|
||||
parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@ -136,11 +136,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--log_interval", type=int, default=100)
|
||||
parser.add_argument("--save_interval", type=int, default=100)
|
||||
parser.add_argument('--local_rank', type=int, default=-1)
|
||||
parser.add_argument('--hidden_size', default=768, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=16, type=int)
|
||||
parser.add_argument('--hidden_size', default=512, type=int)
|
||||
parser.add_argument('--num_hidden_layers', default=8, type=int)
|
||||
parser.add_argument('--max_seq_len', default=512, type=int)
|
||||
parser.add_argument('--use_moe', default=False, type=bool)
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_512.jsonl")
|
||||
parser.add_argument("--data_path", type=str, default="../dataset/sft_mini_512.jsonl")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user