diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 3264aa726b..0d235e2c29 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -629,9 +629,7 @@ def main(): def checkpoint(step=None): # Create the pipeline using the trained modules and save it. - scheduler, _ = FlaxPNDMScheduler.from_pretrained( - "CompVis/stable-diffusion-v1-4", subfolder="scheduler" - ) + scheduler, _ = FlaxPNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( "CompVis/stable-diffusion-safety-checker", from_pt=True )