Fix --resume_from_checkpoint step in train_text_to_image.py (#1914)
fix resume step in train_text_to_image example
This commit is contained in:
@@ -685,9 +685,8 @@ def main():
|
|||||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||||
global_step = int(path.split("-")[1])
|
global_step = int(path.split("-")[1])
|
||||||
|
|
||||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
first_epoch = global_step // num_update_steps_per_epoch
|
||||||
first_epoch = resume_global_step // num_update_steps_per_epoch
|
resume_step = global_step % num_update_steps_per_epoch
|
||||||
resume_step = resume_global_step % num_update_steps_per_epoch
|
|
||||||
|
|
||||||
# Only show the progress bar once on each machine.
|
# Only show the progress bar once on each machine.
|
||||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||||
|
|||||||
Reference in New Issue
Block a user