mirror of
https://github.com/jingyaogong/minimind.git
synced 2026-06-06 00:04:50 +00:00
[fix] data skip
This commit is contained in:
@@ -222,7 +222,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@@ -206,7 +206,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@@ -149,7 +149,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@@ -291,7 +291,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@@ -162,7 +162,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@@ -366,7 +366,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@@ -148,7 +148,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@@ -166,7 +166,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
@@ -341,7 +341,7 @@ if __name__ == "__main__":
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
train_sampler and train_sampler.set_epoch(epoch)
|
||||
setup_seed(42 + epoch); indices = torch.randperm(len(train_ds)).tolist()
|
||||
skip = start_step + 1 if (epoch == start_epoch and start_step > 0) else 0
|
||||
skip = start_step if (epoch == start_epoch and start_step > 0) else 0
|
||||
batch_sampler = SkipBatchSampler(train_sampler or indices, args.batch_size, skip)
|
||||
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True)
|
||||
if skip > 0:
|
||||
|
||||
Reference in New Issue
Block a user