mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-04-25 08:48:16 +08:00
[fix] data skip
This commit is contained in:
parent
f7ffdf1fdb
commit
fea69cf338
@ -222,7 +222,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@ -206,7 +206,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@ -149,7 +149,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@ -291,7 +291,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@ -162,7 +162,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@ -366,7 +366,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@ -148,7 +148,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@ -166,7 +166,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@ -341,7 +341,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user