Various Fixes for Flax Dreambooth (#1782)
* Various Fixes for Flax Dreambooth - Correctly update the progress bar every epoch - Allow specifying a pretrained VAE - Allow specifying a revision to pretrained models - Cache compiled models between invocations (speeds up TPU execution a lot!) - Save intermediate checkpoints by specifying `save_steps` * Don't die when save_steps is not set. * Address CR * Address comments * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
@@ -28,6 +28,7 @@ from flax import jax_utils
|
|||||||
from flax.training import train_state
|
from flax.training import train_state
|
||||||
from flax.training.common_utils import shard
|
from flax.training.common_utils import shard
|
||||||
from huggingface_hub import HfFolder, Repository, whoami
|
from huggingface_hub import HfFolder, Repository, whoami
|
||||||
|
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
@@ -37,6 +38,9 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel,
|
|||||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||||
check_min_version("0.10.0.dev0")
|
check_min_version("0.10.0.dev0")
|
||||||
|
|
||||||
|
# Cache compiled models across invocations of this script.
|
||||||
|
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -49,6 +53,19 @@ def parse_args():
|
|||||||
required=True,
|
required=True,
|
||||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pretrained_vae_name_or_path",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to pretrained vae or vae identifier from huggingface.co/models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--revision",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--tokenizer_name",
|
"--tokenizer_name",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -103,6 +120,7 @@ def parse_args():
|
|||||||
default="text-inversion-model",
|
default="text-inversion-model",
|
||||||
help="The output directory where the model predictions and checkpoints will be written.",
|
help="The output directory where the model predictions and checkpoints will be written.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--save_steps", type=int, default=None, help="Save a checkpoint every X steps.")
|
||||||
parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
|
parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--resolution",
|
"--resolution",
|
||||||
@@ -332,7 +350,7 @@ def main():
|
|||||||
|
|
||||||
if cur_class_images < args.num_class_images:
|
if cur_class_images < args.num_class_images:
|
||||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||||
args.pretrained_model_name_or_path, safety_checker=None
|
args.pretrained_model_name_or_path, safety_checker=None, revision=args.revision
|
||||||
)
|
)
|
||||||
pipeline.set_progress_bar_config(disable=True)
|
pipeline.set_progress_bar_config(disable=True)
|
||||||
|
|
||||||
@@ -383,7 +401,11 @@ def main():
|
|||||||
if args.tokenizer_name:
|
if args.tokenizer_name:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
|
||||||
elif args.pretrained_model_name_or_path:
|
elif args.pretrained_model_name_or_path:
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
tokenizer = CLIPTokenizer.from_pretrained(
|
||||||
|
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("No tokenizer specified!")
|
||||||
|
|
||||||
train_dataset = DreamBoothDataset(
|
train_dataset = DreamBoothDataset(
|
||||||
instance_data_root=args.instance_data_dir,
|
instance_data_root=args.instance_data_dir,
|
||||||
@@ -437,15 +459,23 @@ def main():
|
|||||||
elif args.mixed_precision == "bf16":
|
elif args.mixed_precision == "bf16":
|
||||||
weight_dtype = jnp.bfloat16
|
weight_dtype = jnp.bfloat16
|
||||||
|
|
||||||
|
if args.pretrained_vae_name_or_path:
|
||||||
|
# TODO(patil-suraj): Upload flax weights for the VAE
|
||||||
|
vae_arg, vae_kwargs = (args.pretrained_vae_name_or_path, {"from_pt": True})
|
||||||
|
else:
|
||||||
|
vae_arg, vae_kwargs = (args.pretrained_model_name_or_path, {"subfolder": "vae", "revision": args.revision})
|
||||||
|
|
||||||
# Load models and create wrapper for stable diffusion
|
# Load models and create wrapper for stable diffusion
|
||||||
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
text_encoder = FlaxCLIPTextModel.from_pretrained(
|
||||||
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype
|
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
|
||||||
)
|
)
|
||||||
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
|
||||||
args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype
|
vae_arg,
|
||||||
|
dtype=weight_dtype,
|
||||||
|
**vae_kwargs,
|
||||||
)
|
)
|
||||||
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
|
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
|
||||||
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype
|
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optimization
|
# Optimization
|
||||||
@@ -597,6 +627,39 @@ def main():
|
|||||||
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
|
logger.info(f" Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
|
||||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||||
|
|
||||||
|
def checkpoint(step=None):
|
||||||
|
# Create the pipeline using the trained modules and save it.
|
||||||
|
scheduler, _ = FlaxPNDMScheduler.from_pretrained(
|
||||||
|
"CompVis/stable-diffusion-v1-4", subfolder="scheduler"
|
||||||
|
)
|
||||||
|
safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
|
||||||
|
"CompVis/stable-diffusion-safety-checker", from_pt=True
|
||||||
|
)
|
||||||
|
pipeline = FlaxStableDiffusionPipeline(
|
||||||
|
text_encoder=text_encoder,
|
||||||
|
vae=vae,
|
||||||
|
unet=unet,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
safety_checker=safety_checker,
|
||||||
|
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||||
|
)
|
||||||
|
|
||||||
|
outdir = os.path.join(args.output_dir, str(step)) if step else args.output_dir
|
||||||
|
pipeline.save_pretrained(
|
||||||
|
outdir,
|
||||||
|
params={
|
||||||
|
"text_encoder": get_params_to_save(text_encoder_state.params),
|
||||||
|
"vae": get_params_to_save(vae_params),
|
||||||
|
"unet": get_params_to_save(unet_state.params),
|
||||||
|
"safety_checker": safety_checker.params,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.push_to_hub:
|
||||||
|
message = f"checkpoint-{step}" if step is not None else "End of training"
|
||||||
|
repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
|
||||||
|
|
||||||
global_step = 0
|
global_step = 0
|
||||||
|
|
||||||
epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
|
epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0)
|
||||||
@@ -615,9 +678,11 @@ def main():
|
|||||||
)
|
)
|
||||||
train_metrics.append(train_metric)
|
train_metrics.append(train_metric)
|
||||||
|
|
||||||
train_step_progress_bar.update(1)
|
train_step_progress_bar.update(jax.local_device_count())
|
||||||
|
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
if jax.process_index() == 0 and args.save_steps and global_step % args.save_steps == 0:
|
||||||
|
checkpoint(global_step)
|
||||||
if global_step >= args.max_train_steps:
|
if global_step >= args.max_train_steps:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -626,36 +691,8 @@ def main():
|
|||||||
train_step_progress_bar.close()
|
train_step_progress_bar.close()
|
||||||
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
|
epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})")
|
||||||
|
|
||||||
# Create the pipeline using using the trained modules and save it.
|
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
scheduler = FlaxPNDMScheduler(
|
checkpoint()
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
|
||||||
)
|
|
||||||
safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained(
|
|
||||||
"CompVis/stable-diffusion-safety-checker", from_pt=True
|
|
||||||
)
|
|
||||||
pipeline = FlaxStableDiffusionPipeline(
|
|
||||||
text_encoder=text_encoder,
|
|
||||||
vae=vae,
|
|
||||||
unet=unet,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
scheduler=scheduler,
|
|
||||||
safety_checker=safety_checker,
|
|
||||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline.save_pretrained(
|
|
||||||
args.output_dir,
|
|
||||||
params={
|
|
||||||
"text_encoder": get_params_to_save(text_encoder_state.params),
|
|
||||||
"vae": get_params_to_save(vae_params),
|
|
||||||
"unet": get_params_to_save(unet_state.params),
|
|
||||||
"safety_checker": safety_checker.params,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if args.push_to_hub:
|
|
||||||
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user