Various Fixes for Flax Dreambooth (#1782)
* Various Fixes for Flax Dreambooth - Correctly update the progress bar every epoch - Allow specifying a pretrained VAE - Allow specifying a revision to pretrained models - Cache compiled models between invocations (speeds up TPU execution a lot!) - Save intermediate checkpoints by specifying `save_steps` * Don't die when save_steps is not set. * Address CR * Address comments * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -28,6 +28,7 @@ from flax import jax_utils
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import shard
|
||||
from huggingface_hub import HfFolder, Repository, whoami
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
@@ -37,6 +38,9 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel,
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.10.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -49,6 +53,19 @@ def parse_args():
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_vae_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained vae or vae identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
@@ -103,6 +120,7 @@ def parse_args():
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
|
||||
parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
@@ -332,7 +350,7 @@ def main():
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, safety_checker=None
|
||||
args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
@@ -383,7 +401,11 @@ def main():
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("No tokenizer specified!")
|
||||
|
||||
train_dataset = DreamBoothDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
@@ -437,15 +459,23 @@ def main():
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = jnp.bfloat16
|
||||
|
||||
if args.pretrained_vae_name_or_path:
|
||||
# TODO(patil-suraj): Upload flax weights for the VAE
|
||||
vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
|
||||
else:
|
||||
vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
|
||||
)
|
||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
|
||||
vae_arg,
|
||||
dtype=weight_dtype,
|
||||
**vae_kwargs,
|
||||
)
|
||||
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
|
||||
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
|
||||
)
|
||||
|
||||
# Optimization
|
||||
@@ -597,6 +627,39 @@ def main():
|
||||
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
|
||||
def checkpoint(step=None):
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
scheduler, _ = FlaxPNDMScheduler.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", subfolder="scheduler"
|
||||
)
|
||||
safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", from_pt=True
|
||||
)
|
||||
pipeline = FlaxStableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
|
||||
outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
|
||||
pipeline.save_pretrained(
|
||||
outdir,
|
||||
params={
|
||||
"text_encoder": get_params_to_save(text_encoder_state.params),
|
||||
"vae": get_params_to_save(vae_params),
|
||||
"unet": get_params_to_save(unet_state.params),
|
||||
"safety_checker": safety_checker.params,
|
||||
},
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
message = f"checkpoint-{step}" if step is not None else "End of training"
|
||||
repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
|
||||
|
||||
global_step = 0
|
||||
|
||||
epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
|
||||
@@ -615,9 +678,11 @@ def main():
|
||||
)
|
||||
train_metrics.append(train_metric)
|
||||
|
||||
train_step_progress_bar.update(1)
|
||||
train_step_progress_bar.update(jax.local_device_count())
|
||||
|
||||
global_step += 1
|
||||
if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
|
||||
checkpoint(global_step)
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
@@ -626,36 +691,8 @@ def main():
|
||||
train_step_progress_bar.close()
|
||||
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
|
||||
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if jax.process_index() == 0:
|
||||
scheduler = FlaxPNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||
)
|
||||
safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
|
||||
"CompVis/stable-diffusion-safety-checker", from_pt=True
|
||||
)
|
||||
pipeline = FlaxStableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
|
||||
pipeline.save_pretrained(
|
||||
args.output_dir,
|
||||
params={
|
||||
"text_encoder": get_params_to_save(text_encoder_state.params),
|
||||
"vae": get_params_to_save(vae_params),
|
||||
"unet": get_params_to_save(unet_state.params),
|
||||
"safety_checker": safety_checker.params,
|
||||
},
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
||||
checkpoint()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user