mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[feat] shuffle data
This commit is contained in:
parent
805744e60a
commit
a82526da11
@ -199,7 +199,7 @@ if __name__ == "__main__":
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
@ -248,7 +248,7 @@ if __name__ == "__main__":
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
@ -231,7 +231,7 @@ if __name__ == "__main__":
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
@ -185,7 +185,7 @@ if __name__ == "__main__":
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
@ -299,7 +299,7 @@ if __name__ == "__main__":
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=False,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
@ -196,7 +196,7 @@ if __name__ == "__main__":
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
@ -344,7 +344,7 @@ if __name__ == "__main__":
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=(args.max_seq_len + args.max_gen_len))
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=False,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
|
||||
# 初始化优化器
|
||||
|
||||
@ -183,7 +183,7 @@ if __name__ == "__main__":
|
||||
batch_size=args.batch_size,
|
||||
pin_memory=True,
|
||||
drop_last=False,
|
||||
shuffle=False,
|
||||
shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers,
|
||||
sampler=train_sampler
|
||||
)
|
||||
|
||||
@ -348,7 +348,7 @@ if __name__ == "__main__":
|
||||
train_ds = RLAIFDataset(args.data_path, tokenizer, max_length=lm_config.max_seq_len)
|
||||
train_sampler = DistributedSampler(train_ds) if ddp else None
|
||||
train_loader = DataLoader(train_ds, batch_size=args.batch_size, pin_memory=True,
|
||||
drop_last=False, shuffle=False,
|
||||
drop_last=False, shuffle=(train_sampler is None),
|
||||
num_workers=args.num_workers, sampler=train_sampler)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user