From bcd6f3f9ce77a96511951f3e66659ba49a602167 Mon Sep 17 00:00:00 2001
From: Yasyf Mohamedali
Date: Wed, 4 Jan 2023 13:49:56 -0800
Subject: [PATCH] Various Fixes for Flax Dreambooth (#1782)

* Various Fixes for Flax Dreambooth

- Correctly update the progress bar every epoch
- Allow specifying a pretrained VAE
- Allow specifying a revision for pretrained models
- Cache compiled models between invocations (speeds up TPU execution a lot!)
- Save intermediate checkpoints by specifying `save_steps`

* Don't die when save_steps is not set.

* Address CR

* Address comments

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca
Co-authored-by: Suraj Patil
Co-authored-by: Patrick von Platen
Co-authored-by: Pedro Cuenca
---
 examples/dreambooth/train_dreambooth_flax.py | 107 +++++++++++++------
 1 file changed, 72 insertions(+), 35 deletions(-)

diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index 8cb33df30e..3264aa726b 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -28,6 +28,7 @@ from flax import jax_utils
 from flax.training import train_state
 from flax.training.common_utils import shard
 from huggingface_hub import HfFolder, Repository, whoami
+from jax.experimental.compilation_cache import compilation_cache as cc
 from PIL import Image
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -37,6 +38,9 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel,
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")
 
+# Cache compiled models across invocations of this script.
+cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
+
 logger = logging.getLogger(__name__)
 
@@ -49,6 +53,19 @@ def parse_args():
         required=True,
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--pretrained_vae_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained vae or vae identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--revision",
+        type=str,
+        default=None,
+        required=False,
+        help="Revision of pretrained model identifier from huggingface.co/models.",
+    )
     parser.add_argument(
         "--tokenizer_name",
         type=str,
@@ -103,6 +120,7 @@ def parse_args():
         default="text-inversion-model",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
+    parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
     parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
     parser.add_argument(
         "--resolution",
@@ -332,7 +350,7 @@ def main():
 
         if cur_class_images < args.num_class_images:
             pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, safety_checker=None
+                args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
             )
             pipeline.set_progress_bar_config(disable=True)
 
@@ -383,7 +401,11 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
+        tokenizer = CLIPTokenizer.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
+        )
+    else:
+        raise NotImplementedError("No tokenizer specified!")
 
     train_dataset = DreamBoothDataset(
         instance_data_root=args.instance_data_dir,
@@ -437,15 +459,23 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = jnp.bfloat16
 
+    if args.pretrained_vae_name_or_path:
+        # TODO(patil-suraj): Upload flax weights for the VAE
+        vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
+    else:
+        vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})
+
     # Load models and create wrapper for stable diffusion
     text_encoder = FlaxCLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
+        args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
     )
     vae, vae_params = FlaxAutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
+        vae_arg,
+        dtype=weight_dtype,
+        **vae_kwargs,
     )
     unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
+        args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
     )
 
     # Optimization
@@ -597,6 +627,39 @@ def main():
     logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
 
+    def checkpoint(step=None):
+        # Create the pipeline using the trained modules and save it.
+        scheduler, _ = FlaxPNDMScheduler.from_pretrained(
+            "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
+        )
+        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
+            "CompVis/stable-diffusion-safety-checker", from_pt=True
+        )
+        pipeline = FlaxStableDiffusionPipeline(
+            text_encoder=text_encoder,
+            vae=vae,
+            unet=unet,
+            tokenizer=tokenizer,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+        )
+
+        outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
+        pipeline.save_pretrained(
+            outdir,
+            params={
+                "text_encoder": get_params_to_save(text_encoder_state.params),
+                "vae": get_params_to_save(vae_params),
+                "unet": get_params_to_save(unet_state.params),
+                "safety_checker": safety_checker.params,
+            },
+        )
+
+        if args.push_to_hub:
+            message = f"checkpoint-{step}" if step is not None else "End of training"
+            repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
+
     global_step = 0
 
     epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
@@ -615,9 +678,11 @@ def main():
             )
             train_metrics.append(train_metric)
 
-            train_step_progress_bar.update(1)
+            train_step_progress_bar.update(jax.local_device_count())
             global_step += 1
 
+            if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
+                checkpoint(global_step)
             if global_step >= args.max_train_steps:
                 break
 
@@ -626,36 +691,8 @@ def main():
         train_step_progress_bar.close()
         epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
 
-    # Create the pipeline using using the trained modules and save it.
     if jax.process_index() == 0:
-        scheduler = FlaxPNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        )
-        safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
-            "CompVis/stable-diffusion-safety-checker", from_pt=True
-        )
-        pipeline = FlaxStableDiffusionPipeline(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-            safety_checker=safety_checker,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
-        )
-
-        pipeline.save_pretrained(
-            args.output_dir,
-            params={
-                "text_encoder": get_params_to_save(text_encoder_state.params),
-                "vae": get_params_to_save(vae_params),
-                "unet": get_params_to_save(unet_state.params),
-                "safety_checker": safety_checker.params,
-            },
-        )
-
-        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+        checkpoint()
 
 
 if __name__ == "__main__":
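
Note on the compilation cache: the `cc.initialize_cache(...)` call this patch adds at import
time is what makes repeat invocations fast on TPU. XLA persists each compiled program to the
given directory and reuses it on the next run instead of recompiling. A minimal standalone
sketch of that behavior, assuming nothing beyond the call used in the patch (the `train_step`
function and values below are illustrative, not from the patch):

    import os

    import jax
    import jax.numpy as jnp
    from jax.experimental.compilation_cache import compilation_cache as cc

    # Same call the patch adds at import time: compiled XLA programs are
    # written to this directory and reused by later invocations.
    cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

    @jax.jit
    def train_step(params, grads):
        # Illustrative stand-in for the real (pmapped) train step. The first
        # call compiles (or hits the on-disk cache); later calls reuse it.
        return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

    params = {"w": jnp.ones((4,)), "b": jnp.zeros((4,))}
    print(train_step(params, params))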
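
Note on intermediate checkpoints: with `--save_steps` set, the new `checkpoint()` helper
writes each intermediate save to `<output_dir>/<global_step>` as a full pipeline, so any of
those directories can be loaded directly. A hypothetical example of consuming one (the
output path and step number are made up for illustration):

    from diffusers import FlaxStableDiffusionPipeline

    # A run with --output_dir=dreambooth-model --save_steps=500 would leave
    # pipelines at dreambooth-model/500, dreambooth-model/1000, and so on,
    # plus the final one at dreambooth-model itself.
    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("dreambooth-model/500")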