Compare commits

...

78 Commits

Author SHA1 Message Date
Dhruv Nair 041415a076 pinn ruff 2023-12-05 12:05:34 +00:00
Steven Liu 4684ea2fe8 [docs] #Copied from mechanism (#6007)
* copied from section

* feedback
2023-12-04 10:12:52 -08:00
Steven Liu b64f835ea7 [docs] Add Kandinsky 3 (#5988)
* add

* fix api docs

* edits
2023-12-04 10:11:15 -08:00
Linoy Tsaban 880c0fdd36 [advanced dreambooth lora training script][bug_fix] change token_abstraction type to str (#6040)
* improve help tags

* style fix

* changes token_abstraction type to string.
support multiple concepts for pivotal using a comma separated string.

* style fixup

* changed logger to warning (not yet available)

* moved the token_abstraction parsing to be in the same block as where we create the mapping of identifier to token

---------

Co-authored-by: Linoy <linoy@huggingface.co>
2023-12-04 18:38:44 +01:00
RuoyiDu c36f1c3160 [Community Pipeline] DemoFusion: Democratising High-Resolution Image Generation With No $$$ (#6022)
* Add files via upload

* Update README.md

* Update pipeline_demofusion_sdxl.py

* Update pipeline_demofusion_sdxl.py

* Update examples/community/README.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-04 19:44:57 +05:30
takuoko 0a08d41961 [Feature] Support IP-Adapter Plus (#5915)
* Support IP-Adapter Plus

* fix format

* restore before black format

* restore before black format

* generic

* Refactor PerceiverAttention

* format

* fix test and refactor PerceiverAttention

* generic encode_image

* keep attention implementation

* merge tests

* encode_image backward compatible

* code quality

* fix controlnet inpaint pipeline

* refactor FFN

* refactor FFN

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2023-12-04 12:43:34 +01:00
Levi McCallum e185084a5d Add variant argument to dreambooth lora sdxl advanced (#6021)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-04 12:04:15 +01:00
Dhruv Nair b21729225a Update Tests Fetcher (#5950)
* update setup and deps table

* update

* update

* update

* up

* up

* update

* up

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* quality fix

* fix failure reporting

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-04 12:59:41 +05:30
Parth38 8a812e4e14 Update value_guided_sampling.py (#6027)
* Update value_guided_sampling.py

Changed the scheduler step function as predict_epsilon parameter is not there in latest  DDPM Scheduler

* Update value_guided_sampling.md

Updated a link to a working notebook

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-04 10:36:25 +05:30
gujing bf92e746c0 fix StableDiffusionTensorRT super args error (#6009) 2023-12-04 10:06:23 +05:30
Linoy Tsaban b785a155d6 [advanced dreambooth lora sdxl training script] improve help tags (#6035)
* improve help tags

* style fix

---------

Co-authored-by: Linoy <linoy@huggingface.co>
2023-12-04 09:41:25 +05:30
Sayak Paul d486f0e846 [LoRA serialization] fix: duplicate unet prefix problem. (#5991)
* fix: duplicate unet prefix problem.

* Update src/diffusers/loaders/lora.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-02 21:35:16 +05:30
Sayak Paul 3351270627 [PixArt Tests] remove fast tests from slow suite (#5945)
remove fast tests from slow suite
2023-12-02 20:58:27 +05:30
Junsong Chen 4520e1221a adapt PixArtAlphaPipeline for pixart-lcm model (#5974)
* adapt PixArtAlphaPipeline for pixart-lcm model

* remove original_inference_steps from __call__

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-12-02 13:30:40 +05:30
Long(Tony) Lian 618260409f LLMGroundedDiffusionPipeline: inherit from DiffusionPipeline and fix peft (#6023)
* LLMGroundedDiffusionPipeline: inherit from DiffusionPipeline and fix peft

* Use main in the revision in the examples

* Add "Copied from" statements in comments

* Fix formatting with ruff
2023-12-01 09:58:25 -10:00
Patrick von Platen dadd55fb36 Post Release: v0.24.0 (#5985)
* Post Release: v0.24.0

* post pone deprecation

* post pone deprecation

* Add model_index.json
2023-12-01 18:43:44 +01:00
YiYi Xu 1b6c7ea74e [schedulers] create self.sigmas during __init__ (#6006)
* fix dpm
* all scheulers
2023-12-01 07:15:37 -10:00
YiYi Xu b41f809a4e [Kandinsky 3.0] Follow-up TODOs (#5944)
clean-up kendinsky 3.0
2023-12-01 07:14:22 -10:00
Patrick von Platen 0f55c17e17 fix style 2023-12-01 15:59:34 +00:00
Charchit Sharma 5058d27f12 added attention_head_dim, attention_type, resolution_idx (#6011) 2023-12-01 16:26:58 +01:00
M. Tolga Cangöz 748c1b3ec7 [Docs] Update a link (#6014)
* Update the location of Python's version

* Trim trailing whitespace
2023-12-01 16:26:25 +01:00
M. Tolga Cangöz 523507034f [logging] Fix assertion bug (#6012)
Fix assertion bug
2023-12-01 16:26:04 +01:00
hako-mikan 46c751e970 [Community Pipeline] Regional Prompting Pipeline (#6015)
* Update README.md

* Update README.md

* Add files via upload

* Update README.md

* Update examples/community/README.md

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-12-01 16:22:59 +01:00
Patrick von Platen bc1d28c888 [From Single File] Allow Text Encoder to be passed (#6020)
Allow text encoder to be passed
2023-12-01 16:19:04 +01:00
Sayak Paul af378c1dd1 [Easy] minor edits to setup.py (#5996)
minor edits to setup
2023-12-01 20:38:46 +05:30
Steven Liu 6ba4c5395f [docs] Fix SVD video (#6004)
Update svd.md
2023-12-01 16:07:47 +01:00
Linoy Tsaban c1e4529541 [advanced_dreambooth_lora_sdxl_tranining_script] readme fix (#6019)
readme
2023-12-01 15:14:57 +01:00
Linoy Tsaban d29d97b616 [examples/advanced_diffusion_training] bug fixes and improvements for LoRA Dreambooth SDXL advanced training script (#5935)
* imports and readme bug fixes

* bug fix - ensures text_encoder params are dtype==float32 (when using pivotal tuning) even if the rest of the model is loaded in fp16

* added pivotal tuning to readme

* mapping token identifier to new inserted token in validation prompt (if used)

* correct default value of --train_text_encoder_frac

* change default value of  --adam_weight_decay_text_encoder

* validation prompt generations when using pivotal tuning bug fix

* style fix

* textual inversion embeddings name change

* style fix

* bug fix - stopping text encoder optimization halfway

* readme - will include token abstraction and new inserted tokens when using pivotal tuning
- added type to --num_new_tokens_per_abstraction

* style fix

---------

Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
2023-12-01 14:18:43 +01:00
Jongho Choi 7d4a257c7f Remove a duplicated line? (#6010)
Update __init__.py
2023-12-01 15:49:36 +05:30
Kristian Mischke 141cd52d56 Fix LLMGroundedDiffusionPipeline super class arguments (#5993)
* make `requires_safety_checker` a kwarg instead of a positional argument as it's more future-proof

* apply `make style` formatting edits

* add image_encoder to arguments and pass to super constructor
2023-11-30 10:15:14 -10:00
Steven Liu f72b28c75b [docs] Fix video link (#5986)
Update svd.md
2023-11-29 20:52:25 +01:00
Suraj Patil ada8109d5b Fix SVD doc (#5983)
fix url
2023-11-29 19:55:05 +01:00
Patrick von Platen b34acbdcbc [SDXL Turbo] Add some docs (#5982)
* add diffusers example

* add diffusers example

* Comment about making it faster

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2023-11-29 19:52:07 +01:00
Suraj Patil 63f767ef15 Add SVD (#5895)
* begin model

* finish blocks

* add_embedding

* addition_time_embed_dim

* use TimestepEmbedding

* fix temporal res block

* fix time_pos_embed

* fix add_embedding

* add conversion script

* fix model

* up

* add new resnet blocks

* make forward work

* return sample in original shape

* fix temb shape in TemporalResnetBlock

* add spatio temporal transformers

* add vae blocks

* fix blocks

* update

* update

* fix shapes in Alphablender and add time activation in res blcok

* use new blocks

* style

* fix temb shape

* fix SpatioTemporalResBlock

* reuse TemporalBasicTransformerBlock

* fix TemporalBasicTransformerBlock

* use TransformerSpatioTemporalModel

* fix TransformerSpatioTemporalModel

* fix time_context dim

* clean up

* make temb optional

* add blocks

* rename model

* update conversion script

* remove UNetMidBlockSpatioTemporal

* add in init

* remove unused arg

* remove unused arg

* remove more unsed args

* up

* up

* check for None

* update vae

* update up/mid blocks for decoder

* begin pipeline

* adapt scheduler

* add guidance scalings

* fix norm eps in temporal transformers

* add temporal autoencoder

* make pipeline run

* fix frame decodig

* decode in float32

* decode n frames at a time

* pass decoding_t to decode_latents

* fix decode_latents

* vae encode/decode in fp32

* fix dtype in TransformerSpatioTemporalModel

* type image_latents same as image_embeddings

* allow using differnt eps in temporal block for video decoder

* fix default values in vae

* pass num frames in decode

* switch spatial to temporal for mixing in VAE

* fix num frames during split decoding

* cast alpha to sample dtype

* fix attention in MidBlockTemporalDecoder

* fix typo

* fix guidance_scales dtype

* fix missing activation in TemporalDecoder

* skip_post_quant_conv

* add vae conversion

* style

* take guidance scale as input

* up

* allow passing PIL to export_video

* accept fps as arg

* add pipeline and vae in init

* remove hack

* use AutoencoderKLTemporalDecoder

* don't scale image latents

* add unet tests

* clean up unet

* clean TransformerSpatioTemporalModel

* add slow svd test

* clean up

* make temb optional in Decoder mid block

* fix norm eps in TransformerSpatioTemporalModel

* clean up temp decoder

* clean up

* clean up

* use c_noise values for timesteps

* use math for log

* update

* fix copies

* doc

* upcast vae

* update forward pass for gradient checkpointing

* make added_time_ids is tensor

* up

* fix upcasting

* remove post quant conv

* add _resize_with_antialiasing

* fix _compute_padding

* cleanup model

* more cleanup

* more cleanup

* more cleanup

* remove freeu

* remove attn slice

* small clean

* up

* up

* remove extra step kwargs

* remove eta

* remove dropout

* remove callback

* remove merge factor args

* clean

* clean up

* move to dedicated folder

* remove attention_head_dim

* docstr and small fix

* update unet doc strings

* rename decoding_t

* correct linting

* store c_skip and c_out

* cleanup

* clean TemporalResnetBlock

* more cleanup

* clean up vae

* clean up

* begin doc

* more cleanup

* up

* up

* doc

* Improve

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* Apply suggestions from code review

* Default chunk size to None

* add example

* Better

* Apply suggestions from code review

* update doc

* Update src/diffusers/pipelines/stable_diffusion_video/pipeline_stable_diffusion_video.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* style

* Get torch compile working

* up

* rename

* fix doc

* add chunking

* torch compile

* torch compile

* add modelling outputs

* torch compile

* Improve chunking

* Apply suggestions from code review

* Update docs/source/en/using-diffusers/svd.md

* Close diff tag

* remove slicing

* resnet docstr

* add docstr in resnet

* rename

* Apply suggestions from code review

* update tests

* Fix output type latents

* fix more

* fix more

* Update docs/source/en/using-diffusers/svd.md

* fix more

* add pipeline tests

* remove unused arg

* clean  up

* make sure get_scaling receives tensors

* fix euler scheduler

* fix get_scalings

* simply euler for now

* remove old test file

* use randn_tensor to create noise

* fix device for rand tensor

* increase expected_max_difference

* fix test_inference_batch_single_identical

* actually fix test_inference_batch_single_identical

* disable test_save_load_float16

* skip test_float16_inference

* skip test_inference_batch_single_identical

* fix test_xformers_attention_forwardGenerator_pass

* Apply suggestions from code review

* update StableVideoDiffusionPipelineSlowTests

* update image

* add diffusers example

* fix more

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
2023-11-29 19:13:36 +01:00
PENGUINLIONG d1b2a1a957 Fixed custom module importing on Windows (#5891)
* Fixed custom module importing on Windows

Windows use back slash and `os.path.join()` follows that convention.

* Apply suggestions from code review

Co-authored-by: Lucain <lucainp@gmail.com>

* Update pipeline_utils.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Lucain <lucainp@gmail.com>
2023-11-29 16:33:04 +01:00
Kashif Rasul 01782c220e [Wuerstchen] Adapt lora training example scripts to use PEFT (#5959)
* Adapt lora example scripts to use PEFT

* add to_out.0
2023-11-29 16:18:20 +01:00
vahramtadevosyan d63a498c3b [Pipeline] Add TextToVideoZeroSDXLPipeline (#4695)
* integrated sdxl for the text2video-zero pipeline

* make fix-copies

* fixed CI issues

* make fix-copies

* added docs and `copied from` statements

* added fast tests

* made a small change in docs

* quality+style check fix

* updated docs. added controlnet inference with sdxl

* added device compatibility for fast tests

* fixed docstrings

* changing vae upcasting

* remove torch.empty_cache to speed up inference

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* made fast tests to run on dummy models only, fixed copied from statements

* fixed testing utils imports

* Added bullet points for SDXL support

* fixed formatting & quality

* Update tests/pipelines/text_to_video/test_text_to_video_zero_sdxl.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update tests/pipelines/text_to_video/test_text_to_video_zero_sdxl.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fixed minor error for merging

* fixed updates of sdxl

* made fast tests inherit from `PipelineTesterMixin` and run in 3-4secs on CPU

* make style && make quality

* reimplemented fast tests w/o default attn processor

* make style & make quality

* make fix-copies

* make fix-copies

* fixed docs

* make style & make quality & make fix-copies

* bug fix in cross attention

* make style && make quality

* make fix-copies

* fix gpu issues

* make fix-copies

* updated pipeline signature

---------

Co-authored-by: Vahram <vahram.tadevosyan@lambda-loginnode02.cm.cluster>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2023-11-29 16:10:43 +01:00
Marko Kostiv 6a4aad43dc Controlnet ssd 1b support (#5779)
* Add SSD-1B support for controlnet model

* Add conditioning_channels into ControlNet init from unet

* Fix black formatting

* Isort fixes

* Adds SSD-1B controlnet pipeline test with UNetMidBlock2D as mid block

* Overrides failing ssd-1b tests

* Fixes tests after main branch update

* Fixes code quality checks

---------

Co-authored-by: Marko Kostiv <marko@linearity.io>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-11-29 16:10:01 +01:00
Steven Liu ddd8bd53ed [docs] LCM training (#5796)
* first draft

* feedback
2023-11-29 16:08:05 +01:00
JuanCarlosPi 9f7b2cf2dc Support of ip-adapter to the StableDiffusionControlNetInpaintPipeline (#5887)
* Change pipeline_controlnet_inpaint.py to add ip-adapter support. Changes are similar to those in pipeline_controlnet

* Change tests for the StableDiffusionControlNetInpaintPipeline by adding image_encoder: None

* Update src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2023-11-29 16:00:24 +01:00
Sayak Paul 895c4b704b [LoRA refactor] move several state dict conversion utils out of lora.py (#5955)
* move several state dict conversion utils out of lora.py

* check

* check

* check

* check

* check

* check

* check

* revert back

* check

* check

* again check

* maybe fix?

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-29 20:24:21 +05:30
Linh Nguyen 636feba552 Rename output_dir argument (#5916)
Fix typo in output_dir argument: "text-inversion-model" → "dreambooth-model"
2023-11-29 15:47:16 +01:00
Andrés Romero 79dc7df03e [bug fix] Inpainting for MultiAdapter (#5922)
* bug in MultiAdapter for Inpainting

* adapter_input is a list for MultiAdapter

---------

Co-authored-by: andres <andres@hax.ai>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-11-29 15:46:26 +01:00
Charchit Sharma 6031ecbd23 added doc for Kandinsky3.0 (#5937)
* added en doc for Kandinsky3.0

* required changes

* Update docs/source/en/api/pipelines/kandinsky3.md

* Update docs/source/en/api/pipelines/kandinsky3.md

* Update docs/source/en/api/pipelines/kandinsky3.md

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-29 15:32:00 +01:00
Sayak Paul fdd003d8e2 [Tests] Refactor test_examples.py for better readability (#5946)
* control and custom diffusion

* dreambooth

* instructpix2pix and dreambooth ckpting

* t2i adapters.

* text to image ft

* textual inversion

* unconditional

* workflows

* import fix

* fix import
2023-11-29 18:43:59 +05:30
Steven Liu 172acc98b9 [docs] Update pipeline list (#5952)
add to list
2023-11-29 14:08:39 +01:00
estelleafl 5ae3c3a56b [ldm3d] Ldm3d upscaler to community pipeline (#5870)
---------
Co-authored-by: Aflalo <estellea@isl-gpu27.rr.intel.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2023-11-28 09:00:39 -10:00
Soumik Rakshit 21bc59ab24 fix: minor typo in docstring (#5961) 2023-11-28 18:18:34 +05:30
Steven Liu 50a749e909 [docs] Fix space (#5898)
* fix

* minor edits
2023-11-27 11:50:59 -08:00
YiYi Xu d9075be494 [load_textual_inversion]: allow multiple tokens (#5837)
Co-authored-by: yiyixuxu <yixu310@gmail,com>
2023-11-27 06:52:36 -10:00
Patrick von Platen b135b6e905 [From_pretrained] Fix warning (#5948) 2023-11-27 14:35:19 +01:00
T. Xu 14a0d21d2e [Community Pipeline] Diffusion Posterior Sampling for General Noisy Inverse Problems (#5939)
* [community pipeline] dps impl

* add type checking

* pass ruff check

* ruff formatter
2023-11-27 14:29:42 +01:00
Patrick von Platen ebf581e85f [Tests] Make sure that we don't run tests multiple times (#5949)
* [Tests] Make sure that we don't run tests mulitple times

* [Tests] Make sure that we don't run tests mulitple times

* [Tests] Make sure that we don't run tests mulitple times
2023-11-27 14:18:56 +01:00
Patrick von Platen e550163b9f [Vae] Make sure all vae's work with latent diffusion models (#5880)
* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more
2023-11-27 14:17:47 +01:00
Viktor Grygorchuk 20f0cbc88f fix: error on device for lpw_stable_diffusion_xl pipeline if pipe.enable_sequential_cpu_offload() enabled (#5885)
fix: set device for pipe.enable_sequential_cpu_offload()
2023-11-27 13:47:47 +01:00
Chi d72a24b790 Replace multiple variables with one variable. (#5715)
* I added a new doc string to the class. This is more flexible to understanding other developers what are doing and where it's using.

* Update src/diffusers/models/unet_2d_blocks.py

This changes suggest by maintener.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/unet_2d_blocks.py

Add suggested text

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_blocks.py

I changed the Parameter to Args text.

* Update unet_2d_blocks.py

proper indentation set in this file.

* Update unet_2d_blocks.py

a little bit of change in the act_fun argument line.

* I run the black command to reformat style in the code

* Update unet_2d_blocks.py

similar doc-string add to have in the original diffusion repository.

* I enhanced the code by replacing multiple redundant variables with a single variable, as they all served the same purpose. Additionally, I utilized the get_activation function for improved flexibility in choosing activation functions.

* Using as black package to reformated my file

* reverte some changes

* Remove conv_out_padding variables and using as conv_in_padding

* conv_out_padding create and add them into the code.

* run black command to solving styling problem

* add little bit space between comment and import statement

* I am utilizing the ruff library to address the style issues in my Makefile.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-27 13:34:52 +01:00
ginjia d3cda804e7 add LoRA weights load and fuse support for IPEX pipeline (#5920)
add IPEX pipeline LoRA weights loading support
2023-11-27 13:32:43 +01:00
dg845 07eac4d65a Fix LCM Stable Diffusion distillation bug related to parsing unet_time_cond_proj_dim (#5893)
* Fix bug related to parsing unet_time_cond_proj_dim.

* Fix analogous bug in the SD-XL LCM distillation script.
2023-11-27 13:00:40 +01:00
Iván de Prado c079cae3d4 Avoid computing min() that is expensive when do_normalize is False in the image processor (#5896)
Avoid computing min() that is expensive when do_normalize is False

Avoid extra computing when do_normalize is False
2023-11-27 12:46:26 +01:00
Wang, Yi c7bfb8b22a set the model to train state before accelerator prepare (#5099)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2023-11-27 12:43:49 +01:00
dg845 67d070749a Add Custom Timesteps Support to LCMScheduler and Supported Pipelines (#5874)
* Add custom timesteps support to LCMScheduler.

* Add custom timesteps support to StableDiffusionPipeline.

* Add custom timesteps support to StableDiffusionXLPipeline.

* Add custom timesteps support to remaining Stable Diffusion pipelines which support LCMScheduler (img2img, inpaint).

* Add custom timesteps support to remaining Stable Diffusion XL pipelines which support LCMScheduler (img2img, inpaint).

* Add custom timesteps support to StableDiffusionControlNetPipeline.

* Add custom timesteps support to T21 Stable Diffusion (XL) Adapters.

* Clean up Stable Diffusion inpaint tests.

* Manually add support for custom timesteps to AltDiffusion pipelines since make fix-copies doesn't appear to work correctly (it deletes the whole pipeline).

* make style

* Refactor pipeline timestep handling into the retrieve_timesteps function.
2023-11-27 12:39:14 +01:00
Aryan V S 9c357bda3f Deprecate KarrasVeScheduler and ScoreSdeVpScheduler (#5269)
* deprecated: KarrasVeScheduler, ScoreSdeVpScheduler

* delete tests relevant to deprecated schedulers

* chore: run make style

* fix: import error caused due to incorrect _import_structure after deprecation

* fix: ScoreSdeVpScheduler was not importable from diffusers

* remove import added by assumption

* Update src/diffusers/schedulers/__init__.py as suggested by @patrickvonplaten

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make it a part deprecated

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix

* fix

* fix doc

* fix doc....again.......

* remove karras_ve test folder

Co-Authored-By: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
2023-11-27 12:33:02 +01:00
Sayak Paul 3f7c3511dc [Core] add support for gradient checkpointing in transformer_2d (#5943)
add support for gradient checkpointing in transformer_2d
2023-11-27 16:21:12 +05:30
Junsong Chen 7d6f30e89b [Fix: pixart-alpha] random 512px resolution bug (#5842)
* [Fix: pixart-alpha]
add ASPECT_RATIO_512_BIN in use_resolution_binning for random 512px image generation.

* add slow test file for 512px generation without resolution binning

* fix: slow tests for resolution binning.

---------

Co-authored-by: jschen <chenjunsong4@h-partners.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-11-27 13:35:35 +05:30
Patrick von Platen 6d2e19f746 [Examples] Allow downloading variant model files (#5531)
* add variant

* add variant

* Apply suggestions from code review

* reformat

* fix: textual_inversion.py

* fix: variant in model_info

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2023-11-27 10:43:20 +05:30
Patrick von Platen 2a7f43a73b correct num inference steps 2023-11-24 17:09:26 +00:00
Patrick von Platen b978334d71 [@cene555][Kandinsky 3.0] Add Kandinsky 3.0 (#5913)
* finalize

* finalize

* finalize

* add slow test

* add slow test

* add slow test

* Fix more

* add slow test

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Better

* Fix more

* Fix more

* add slow test

* Add auto pipelines

* add slow test

* Add all

* add slow test

* add slow test

* add slow test

* add slow test

* add slow test

* Apply suggestions from code review

* add slow test

* add slow test
2023-11-24 17:46:00 +01:00
Sayak Paul e5f232f76b [Docs] add: 8bit inference with pixart alpha (#5814)
* add: 8bit inference with pixart alpha

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add: note on 4bit.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* address comment

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-24 20:36:33 +05:30
Linoy Tsaban 3003ff4947 [bug fix] fix small bug in readme template of sdxl lora training script (#5914)
readme improvement and metadata fix
2023-11-23 19:08:49 +01:00
Linoy Tsaban 5ffa603244 [bug fix] fix small bug in readme template of sdxl lora training script (#5906)
* readme bug fix

* style fix

---------

Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
2023-11-23 12:11:50 +01:00
Linoy Tsaban 0eeee618cf Adds an advanced version of the SD-XL DreamBooth LoRA training script supporting pivotal tuning (#5883)
* sdxl dreambooth lora training script with pivotal tuning

* bug fix - args missing from parse_args

* code quality fixes

* comment unnecessary code from TokenEmbedding handler class

* fixup

---------

Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
2023-11-22 16:27:56 +01:00
Andrés Romero 93f1a14cab ControlNet+Adapter pipeline, and ControlNet+Adapter+Inpaint pipeline (#5869)
* ControlNet+Adapter pipeline, and +Inpaint pipeline


---------

Co-authored-by: andres <andres@hax.ai>
2023-11-21 08:59:29 -10:00
Patrick von Platen 13d73d9303 [Lora] Seperate logic (#5809)
* [Lora] Seperate logic

* [Lora] Seperate logic

* [Lora] Seperate logic

* add comments to explain the code better

* add comments to explain the code better
2023-11-21 18:58:37 +01:00
YiYi Xu ba352aea29 [feat] IP Adapters (author @okotaku ) (#5713)
* add ip-adapter


---------

Co-authored-by: okotaku <to78314910@gmail.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-11-21 07:34:30 -10:00
Linoy Tsaban 6fac1369d0 Add features to the Dreambooth LoRA SDXL training script (#5508)
* Additions:
- support for different lr for text encoder
- support for Prodigy optimizer
- support for min snr gamma
- support for custom captions and dataset loading from the hub

* adjusted --caption_column behaviour (to -not- use the second column of the dataset by default if --caption_column is not provided)

* fixed --output_dir / --model_dir_name confusion

* added --repeats, --adam_weight_decay_text_encoder
+ some fixes

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* - import compute_snr from diffusers/training_utils.py
- cluster adamw together
- when using 'prodigy', if --train_text_encoder == True and --text_encoder_lr != --learning rate, changes the lr of the text encoders optimization params to be --learning_rate (otherwise errors)

* shape fixes when custom captions are used

* formatting and a little cleanup

* code styling

* --repeats default value fixed, changed to 1

* bug fix - removed redundant lines of embedding concatenation when using prior_preservation (that duplicated class_prompt embeddings)

* changed dataset loading logic according to the following usecases (to avoid unnecessary dependency on datasets)-
1. user provides --dataset_name
2. user provides local dir --instance_data_dir that contains a metadata .jsonl file
3. user provides local dir --instance_data_dir that contains only images
in cases [1,2] we import datasets and use load_dataset method, in case [3] we process the data same as in the original script setting

* styling fix

* arg name fix

* adjusted the --repeats logic

* -removed redundant arg and 'if' when loading local folder with prompts
-updated readme template
-some default val fixes
-custom caption tests

* image path fix for readme

* code style

* bug fix

* --caption_column arg

* readme fix

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
2023-11-21 17:38:43 +01:00
Steven Liu 1093f9d615 [docs] MusicLDM (#5854)
* fix

* feedback
2023-11-21 15:27:41 +01:00
Aryan V S 81780882b8 Addition of new callbacks to controlnets (#5812)
* add new callbacks to src/diffusers/pipelines/controlnet/pipeline_controlnet.py

* update callbacks

* fix repeated kwarg

* update

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-21 15:22:20 +01:00
Dhruv Nair ebc7bedeb7 Add tests fetcher (#5848)
* add tests fetcher to utils

* add test fetcher

* update

* update

* remove unused dependency version check script

* update

* fix mistake

* update

* update

* update

* update

* update

* update

* update

* remove concurrency params

* update

* update

* update

* update

* update

* update

* move test fetcher to dedicated workflow
2023-11-21 18:01:44 +05:30
219 changed files with 29462 additions and 3345 deletions
+176
View File
@@ -0,0 +1,176 @@
name: Fast tests for PRs - Test Fetcher
on:
pull_request:
branches:
- main
push:
branches:
- ci-*
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
setup_pr_tests:
name: Setup PR Tests
runs-on: docker-cpu
container:
image: diffusers/diffusers-pytorch-cpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
outputs:
matrix: ${{ steps.set_matrix.outputs.matrix }}
test_map: ${{ steps.set_matrix.outputs.test_map }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
- name: Environment
run: |
python utils/print_env.py
echo $(git --version)
- name: Fetch Tests
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
- name: Report fetched tests
uses: actions/upload-artifact@v3
with:
name: test_fetched
path: test_preparation.txt
- id: set_matrix
name: Create Test Matrix
# The `keys` is used as GitHub actions matrix for jobs, i.e. `models`, `pipelines`, etc.
# The `test_map` is used to get the actual identified test files under each key.
# If no test to run (so no `test_map.json` file), create a dummy map (empty matrix will fail)
run: |
if [ -f test_map.json ]; then
keys=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); d = list(test_map.keys()); print(json.dumps(d))')
test_map=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); print(json.dumps(test_map))')
else
keys=$(python3 -c 'keys = ["dummy"]; print(keys)')
test_map=$(python3 -c 'test_map = {"dummy": []}; print(test_map)')
fi
echo $keys
echo $test_map
echo "matrix=$keys" >> $GITHUB_OUTPUT
echo "test_map=$test_map" >> $GITHUB_OUTPUT
run_pr_tests:
name: Run PR Tests
needs: setup_pr_tests
if: contains(fromJson(needs.setup_pr_tests.outputs.matrix), 'dummy') != true
strategy:
fail-fast: false
max-parallel: 2
matrix:
modules: ${{ fromJson(needs.setup_pr_tests.outputs.matrix) }}
runs-on: docker-cpu
container:
image: diffusers/diffusers-pytorch-cpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
python -m pip install accelerate
- name: Environment
run: |
python utils/print_env.py
- name: Run all selected tests on CPU
run: |
python -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
- name: Failure short reports
if: ${{ failure() }}
continue-on-error: true
run: |
cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
cat reports/${{ matrix.modules }}_tests_cpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.modules }}_test_reports
path: reports
run_staging_tests:
strategy:
fail-fast: false
matrix:
config:
- name: Hub tests for models, schedulers, and pipelines
framework: hub_tests_pytorch
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_hub
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
- name: Environment
run: |
python utils/print_env.py
- name: Run Hub tests for models, schedulers, and pipelines on a staging env
if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
run: |
HUGGINGFACE_CO_STAGING=true python -m pytest \
-m "is_staging_test" \
--make-reports=tests_${{ matrix.config.report }} \
tests
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
+1 -1
View File
@@ -115,7 +115,7 @@ jobs:
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples/test_examples.py
examples
- name: Failure short reports
if: ${{ failure() }}
+5 -1
View File
@@ -5,6 +5,10 @@ on:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
@@ -96,7 +100,7 @@ jobs:
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_${{ matrix.config.report }} \
examples/test_examples.py
examples
- name: Failure short reports
if: ${{ failure() }}
+4
View File
@@ -13,6 +13,10 @@ env:
PYTEST_TIMEOUT: 600
RUN_SLOW: no
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
run_fast_tests_apple_m1:
name: Fast PyTorch MPS tests on MacOS
+1 -1
View File
@@ -355,7 +355,7 @@ You will need basic `git` proficiency to be able to contribute to
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L244)):
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/main/setup.py#L265)):
1. Fork the [repository](https://github.com/huggingface/diffusers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
+1 -1
View File
@@ -41,7 +41,7 @@ repo-consistency:
quality:
ruff check $(check_dirs) setup.py
ruff format --check $(check_dirs) setup.py
ruff format --check $(check_dirs) setup.py
python utils/check_doc_toc.py
# Format source code automatically and check is there are any problems left that need manual fixing
+1 -1
View File
@@ -82,7 +82,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
The following design principles are followed:
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modelling files and shows that models do not really follow the single-file policy.
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
- Models all inherit from `ModelMixin` and `ConfigMixin`.
- Models can be optimized for performance when it doesnt demand major code changes, keep backward compatibility, and give significant memory or compute gain.
+11 -1
View File
@@ -72,6 +72,8 @@
title: Overview
- local: using-diffusers/sdxl
title: Stable Diffusion XL
- local: using-diffusers/sdxl_turbo
title: SDXL Turbo
- local: using-diffusers/kandinsky
title: Kandinsky
- local: using-diffusers/controlnet
@@ -94,6 +96,8 @@
title: Latent Consistency Model-LoRA
- local: using-diffusers/inference_with_lcm
title: Latent Consistency Model
- local: using-diffusers/svd
title: Stable Video Diffusion
title: Specific pipeline examples
- sections:
- local: training/overview
@@ -129,6 +133,8 @@
title: LoRA
- local: training/custom_diffusion
title: Custom Diffusion
- local: training/lcm_distill
title: Latent Consistency Distillation
- local: training/ddpo
title: Reinforcement learning training with DDPO
title: Methods
@@ -278,6 +284,8 @@
title: Kandinsky 2.1
- local: api/pipelines/kandinsky_v22
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
@@ -327,12 +335,14 @@
title: Stable Diffusion 2
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
title: Stable Diffusion XL
- local: api/pipelines/stable_diffusion/sdxl_turbo
title: SDXL Turbo
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/upscale
title: Super-resolution
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth)
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
- local: api/pipelines/stable_diffusion/adapter
title: Stable Diffusion T2I-Adapter
- local: api/pipelines/stable_diffusion/gligen
@@ -0,0 +1,49 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 3
Kandinsky 3 is created by [Vladimir Arkhipkin](https://github.com/oriBetelgeuse),[Anastasia Maltseva](https://github.com/NastyaMittseva),[Igor Pavlov](https://github.com/boomb0om),[Andrei Filatov](https://github.com/anvilarth),[Arseniy Shakhmatov](https://github.com/cene555),[Andrey Kuznetsov](https://github.com/kuznetsoffandrey),[Denis Dimitrov](https://github.com/denndimitrov), [Zein Shaheen](https://github.com/zeinsh)
The description from it's Github page:
*Kandinsky 3.0 is an open-source text-to-image diffusion model built upon the Kandinsky2-x model family. In comparison to its predecessors, enhancements have been made to the text understanding and visual quality of the model, achieved by increasing the size of the text encoder and Diffusion U-Net models, respectively.*
Its architecture includes 3 main components:
1. [FLAN-UL2](https://huggingface.co/google/flan-ul2), which is an encoder decoder model based on the T5 architecture.
2. New U-Net architecture featuring BigGAN-deep blocks doubles depth while maintaining the same number of parameters.
3. Sber-MoVQGAN is a decoder proven to have superior results in image restoration.
The original codebase can be found at [ai-forever/Kandinsky-3](https://github.com/ai-forever/Kandinsky-3).
<Tip>
Check out the [Kandinsky Community](https://huggingface.co/kandinsky-community) organization on the Hub for the official model checkpoints for tasks like text-to-image, image-to-image, and inpainting.
</Tip>
<Tip>
Make sure to check out the schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## Kandinsky3Pipeline
[[autodoc]] Kandinsky3Pipeline
- all
- __call__
## Kandinsky3Img2ImgPipeline
[[autodoc]] Kandinsky3Img2ImgPipeline
- all
- __call__
+2 -1
View File
@@ -51,9 +51,10 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [InstructPix2Pix](pix2pix) | image editing |
| [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation |
| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
| [Kandinsky 3](kandinsky3) | text2image, image2image |
| [Latent Consistency Models](latent_consistency_models) | text2image |
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D |
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D, text-to-pano, upscaling |
| [MultiDiffusion](panorama) | text2image |
| [MusicLDM](musicldm) | text2audio |
| [Paint by Example](paint_by_example) | inpainting |
+106
View File
@@ -35,6 +35,112 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>
## Inference with under 8GB GPU VRAM
Run the [`PixArtAlphaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.
First, install the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library:
```bash
pip install -U bitsandbytes
```
Then load the text encoder in 8-bit:
```python
from transformers import T5EncoderModel
from diffusers import PixArtAlphaPipeline
import torch
text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
subfolder="text_encoder",
load_in_8bit=True,
device_map="auto",
)
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
text_encoder=text_encoder,
transformer=None,
device_map="auto"
)
```
Now, use the `pipe` to encode a prompt:
```python
with torch.no_grad():
prompt = "cute cat"
prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)
```
Since text embeddings have been computed, remove the `text_encoder` and `pipe` from the memory, and free up som GPU VRAM:
```python
import gc
def flush():
gc.collect()
torch.cuda.empty_cache()
del text_encoder
del pipe
flush()
```
Then compute the latents with the prompt embeddings as inputs:
```python
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
text_encoder=None,
torch_dtype=torch.float16,
).to("cuda")
latents = pipe(
negative_prompt=None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
num_images_per_prompt=1,
output_type="latent",
).images
del pipe.transformer
flush()
```
<Tip>
Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
</Tip>
Once the latents are computed, pass it off to the VAE to decode into a real image:
```python
with torch.no_grad():
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
image.save("cat.png")
```
By deleting components you aren't using and flushing the GPU VRAM, you should be able to run [`PixArtAlphaPipeline`] with under 8GB GPU VRAM.
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png)
If you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e).
<Tip warning={true}>
Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
</Tip>
While loading the `text_encoder`, you set `load_in_8bit` to `True`. You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB.
## PixArtAlphaPipeline
[[autodoc]] PixArtAlphaPipeline
@@ -14,6 +14,11 @@ specific language governing permissions and limitations under the License.
LDM3D was proposed in [LDM3D: Latent Diffusion Model for 3D](https://huggingface.co/papers/2305.10853) by Gabriela Ben Melech Stan, Diana Wofk, Scottie Fox, Alex Redden, Will Saxton, Jean Yu, Estelle Aflalo, Shao-Yen Tseng, Fabio Nonato, Matthias Muller, and Vasudev Lal. LDM3D generates an image and a depth map from a given text prompt unlike the existing text-to-image diffusion models such as [Stable Diffusion](./overview) which only generates an image. With almost the same number of parameters, LDM3D achieves to create a latent space that can compress both the RGB images and the depth maps.
Two checkpoints are available for use:
- [ldm3d-original](https://huggingface.co/Intel/ldm3d). The original checkpoint used in the [paper](https://arxiv.org/pdf/2305.10853.pdf)
- [ldm3d-4c](https://huggingface.co/Intel/ldm3d-4c). The new version of LDM3D using 4 channels inputs instead of 6-channels inputs and finetuned on higher resolution images.
The abstract from the paper is:
*This research paper proposes a Latent Diffusion Model for 3D (LDM3D) that generates both image and depth map data from a given text prompt, allowing users to generate RGBD images from text prompts. The LDM3D model is fine-tuned on a dataset of tuples containing an RGB image, depth map and caption, and validated through extensive experiments. We also develop an application called DepthFusion, which uses the generated RGB images and depth maps to create immersive and interactive 360-degree-view experiences using TouchDesigner. This technology has the potential to transform a wide range of industries, from entertainment and gaming to architecture and design. Overall, this paper presents a significant contribution to the field of generative AI and computer vision, and showcases the potential of LDM3D and DepthFusion to revolutionize content creation and digital experiences. A short video summarizing the approach can be found at [this url](https://t.ly/tdi2).*
@@ -26,12 +31,25 @@ Make sure to check out the Stable Diffusion [Tips](overview#tips) section to lea
## StableDiffusionLDM3DPipeline
[[autodoc]] StableDiffusionLDM3DPipeline
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline
- all
- __call__
## LDM3DPipelineOutput
[[autodoc]] pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.LDM3DPipelineOutput
- all
- __call__
# Upscaler
[LDM3D-VR](https://arxiv.org/pdf/2311.03226.pdf) is an extended version of LDM3D.
The abstract from the paper is:
*Latent diffusion models have proven to be state-of-the-art in the creation and manipulation of visual outputs. However, as far as we know, the generation of depth maps jointly with RGB is still limited. We introduce LDM3D-VR, a suite of diffusion models targeting virtual reality development that includes LDM3D-pano and LDM3D-SR. These models enable the generation of panoramic RGBD based on textual prompts and the upscaling of low-resolution inputs to high-resolution RGBD, respectively. Our models are fine-tuned from existing pretrained models on datasets containing panoramic/high-resolution RGB images, depth maps and captions. Both models are evaluated in comparison to existing related methods*
Two checkpoints are available for use:
- [ldm3d-pano](https://huggingface.co/Intel/ldm3d-pano). This checkpoint enables the generation of panoramic images and requires the StableDiffusionLDM3DPipeline pipeline to be used.
- [ldm3d-sr](https://huggingface.co/Intel/ldm3d-sr). This checkpoint enables the upscaling of RGB and depth images. Can be used in cascade after the original LDM3D pipeline using the StableDiffusionUpscaleLDM3DPipeline from communauty pipeline.
@@ -121,10 +121,16 @@ The table below summarizes the available Stable Diffusion pipelines, their suppo
<td class="px-4 py-2 text-gray-700">
<a href="./ldm3d_diffusion">StableDiffusionLDM3D</a>
</td>
<td class="px-4 py-2 text-gray-700">text-to-rgb, text-to-depth</td>
<td class="px-4 py-2 text-gray-700">text-to-rgb, text-to-depth, text-to-pano</td>
<td class="px-4 py-2"><a href="https://huggingface.co/spaces/r23/ldm3d-space"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue"/></a>
</td>
</tr>
<tr>
<td class="px-4 py-2 text-gray-700">
<a href="./ldm3d_diffusion">StableDiffusionUpscaleLDM3D</a>
</td>
<td class="px-4 py-2 text-gray-700">ldm3d super-resolution</td>
</tr>
</tbody>
</table>
</div>
@@ -0,0 +1,53 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# SDXL Turbo
Stable Diffusion XL (SDXL) Turbo was proposed in [Adversarial Diffusion Distillation](https://stability.ai/research/adversarial-diffusion-distillation) by Axel Sauer, Dominik Lorenz, Andreas Blattmann, and Robin Rombach.
The abstract from the paper is:
*We introduce Adversarial Diffusion Distillation (ADD), a novel training approach that efficiently samples large-scale foundational image diffusion models in just 14 steps while maintaining high image quality. We use score distillation to leverage large-scale off-the-shelf image diffusion models as a teacher signal in combination with an adversarial loss to ensure high image fidelity even in the low-step regime of one or two sampling steps. Our analyses show that our model clearly outperforms existing few-step methods (GANs,Latent Consistency Models) in a single step and reaches the performance of state-of-the-art diffusion models (SDXL) in only four steps. ADD is the first method to unlock single-step, real-time image synthesis with foundation models.*
## Tips
- SDXL Turbo uses the exact same architecture as [SDXL](./stable_diffusion_xl).
- SDXL Turbo should disable guidance scale by setting `guidance_scale=0.0`
- SDXL Turbo should use `timestep_spacing='trailing'` for the scheduler and use between 1 and 4 steps.
- SDXL Turbo has been trained to generate images of size 512x512.
- SDXL Turbo is open-access, but not open-source meaning that one might have to buy a model license in order to use it for commercial applications. Make sure to read the [official model card](https://huggingface.co/stabilityai/sdxl-turbo) to learn more.
<Tip>
To learn how to use SDXL Turbo for various tasks, how to optimize performance, and other usage examples, take a look at the [Stable Diffusion XL](../../../using-diffusers/sdxl_turbo) guide.
Check out the [Stability AI](https://huggingface.co/stabilityai) Hub organization for the official base and refiner model checkpoints!
</Tip>
## StableDiffusionXLPipeline
[[autodoc]] StableDiffusionXLPipeline
- all
- __call__
## StableDiffusionXLImg2ImgPipeline
[[autodoc]] StableDiffusionXLImg2ImgPipeline
- all
- __call__
## StableDiffusionXLInpaintPipeline
[[autodoc]] StableDiffusionXLInpaintPipeline
- all
- __call__
@@ -92,6 +92,19 @@ imageio.mimsave("video.mp4", result, fps=4)
```
- #### SDXL Support
In order to use the SDXL model when generating a video from prompt, use the `TextToVideoZeroSDXLPipeline` pipeline:
```python
import torch
from diffusers import TextToVideoZeroSDXLPipeline
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
```
### Text-To-Video with Pose Control
To generate a video from prompt with additional pose control
@@ -141,7 +154,33 @@ To generate a video from prompt with additional pose control
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
- #### SDXL Support
Since our attention processor also works with SDXL, it can be utilized to generate a video from prompt using ControlNet models powered by SDXL:
```python
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_zero import CrossFrameAttnProcessor
controlnet_model_id = 'thibaud/controlnet-openpose-sdxl-1.0'
model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
controlnet = ControlNetModel.from_pretrained(controlnet_model_id, torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
model_id, controlnet=controlnet, torch_dtype=torch.float16
).to('cuda')
# Set the attention processor
pipe.unet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
pipe.controlnet.set_attn_processor(CrossFrameAttnProcessor(batch_size=2))
# fix latents for all frames
latents = torch.randn((1, 4, 128, 128), device="cuda", dtype=torch.float16).repeat(len(pose_images), 1, 1, 1)
prompt = "Darth Vader dancing in a desert"
result = pipe(prompt=[prompt] * len(pose_images), image=pose_images, latents=latents).images
imageio.mimsave("video.mp4", result, fps=4)
```
### Text-To-Video with Edge Control
@@ -253,5 +292,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
## TextToVideoZeroSDXLPipeline
[[autodoc]] TextToVideoZeroSDXLPipeline
- all
- __call__
## TextToVideoPipelineOutput
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.TextToVideoPipelineOutput
@@ -24,7 +24,7 @@ The abstract from the paper is:
*Model-based reinforcement learning methods often use learning only for the purpose of estimating an approximate dynamics model, offloading the rest of the decision-making work to classical trajectory optimizers. While conceptually simple, this combination has a number of empirical shortcomings, suggesting that learned models may not be well-suited to standard trajectory optimization. In this paper, we consider what it would look like to fold as much of the trajectory optimization pipeline as possible into the modeling problem, such that sampling from the model and planning with it become nearly identical. The core of our technical approach lies in a diffusion probabilistic model that plans by iteratively denoising trajectories. We show how classifier-guided sampling and image inpainting can be reinterpreted as coherent planning strategies, explore the unusual and useful properties of diffusion-based planning methods, and demonstrate the effectiveness of our framework in control settings that emphasize long-horizon decision-making and test-time flexibility.*
You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/reinforcement_learning_with_diffusers.ipynb).
You can find additional information about the model on the [project page](https://diffusion-planning.github.io/), the [original codebase](https://github.com/jannerm/diffuser), or try it out in a demo [notebook](https://colab.research.google.com/drive/1rXm8CX4ZdN5qivjJ2lhwhkOmt_m0CvU0#scrollTo=6HXJvhyqcITc&uniqifier=1).
The script to run the model is available [here](https://github.com/huggingface/diffusers/tree/main/examples/reinforcement_learning).
@@ -25,4 +25,4 @@ The abstract from the paper is:
</Tip>
## ScoreSdeVpScheduler
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
[[autodoc]] schedulers.deprecated.scheduling_sde_vp.ScoreSdeVpScheduler
@@ -18,4 +18,4 @@ specific language governing permissions and limitations under the License.
[[autodoc]] KarrasVeScheduler
## KarrasVeOutput
[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeOutput
[[autodoc]] schedulers.deprecated.scheduling_karras_ve.KarrasVeOutput
+27 -7
View File
@@ -297,17 +297,37 @@ if you don't know yet what specific component you would like to add:
- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that
we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](philosophy) a read to better understand the design of any of the three components. Please be aware that we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
original author directly on the PR so that they can follow the progress and potentially help with questions.
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the original author directly on the PR so that they can follow the progress and potentially help with questions.
If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
#### Copied from mechanism
A unique and important feature to understand when adding any pipeline, model or scheduler code is the `# Copied from` mechanism. You'll see this all over the Diffusers codebase, and the reason we use it is to keep the codebase easy to understand and maintain. Marking code with the `# Copied from` mechanism forces the marked code to be identical to the code it was copied from. This makes it easy to update and propagate changes across many files whenever you run `make fix-copies`.
For example, in the code example below, [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is the original code and `AltDiffusionPipelineOutput` uses the `# Copied from` mechanism to copy it. The only difference is changing the class prefix from `Stable` to `Alt`.
```py
# Copied from diffusers.pipelines.stable_diffusion.pipeline_output.StableDiffusionPipelineOutput with Stable->Alt
class AltDiffusionPipelineOutput(BaseOutput):
"""
Output class for Alt Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
num_channels)`.
nsfw_content_detected (`List[bool]`)
List indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content or
`None` if safety checking could not be performed.
"""
```
To learn more, read this section of the [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) blog post.
## How to write a good issue
**The better your issue is written, the higher the chances that it will be quickly resolved.**
+255
View File
@@ -0,0 +1,255 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Latent Consistency Distillation
[Latent Consistency Models (LCMs)](https://hf.co/papers/2310.04378) are able to generate high-quality images in just a few steps, representing a big leap forward because many pipelines require at least 25+ steps. LCMs are produced by applying the latent consistency distillation method to any Stable Diffusion model. This method works by applying *one-stage guided distillation* to the latent space, and incorporating a *skipping-step* method to consistently skip timesteps to accelerate the distillation process (refer to section 4.1, 4.2, and 4.3 of the paper for more details).
If you're training on a GPU with limited vRAM, try enabling `gradient_checkpointing`, `gradient_accumulation_steps`, and `mixed_precision` to reduce memory-usage and speedup training. You can reduce your memory-usage even more by enabling memory-efficient attention with [xFormers](../optimization/xformers) and [bitsandbytes'](https://github.com/TimDettmers/bitsandbytes) 8-bit optimizer.
This guide will explore the [train_lcm_distill_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) script to help you become more familiar with it, and how you can adapt it for your own use-case.
Before running the script, make sure you install the library from source:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then navigate to the example folder containing the training script and install the required dependencies for the script you're using:
```bash
cd examples/consistency_distillation
pip install -r requirements.txt
```
<Tip>
🤗 Accelerate is a library for helping you train on multiple GPUs/TPUs or with mixed-precision. It'll automatically configure your training setup based on your hardware and environment. Take a look at the 🤗 Accelerate [Quick tour](https://huggingface.co/docs/accelerate/quicktour) to learn more.
</Tip>
Initialize an 🤗 Accelerate environment (try enabling `torch.compile` to significantly speedup training):
```bash
accelerate config
```
To setup a default 🤗 Accelerate environment without choosing any configurations:
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
```bash
from accelerate.utils import write_basic_config
write_basic_config()
```
Lastly, if you want to train a model on your own dataset, take a look at the [Create a dataset for training](create_dataset) guide to learn how to create a dataset that works with the training script.
## Script parameters
<Tip>
The following sections highlight parts of the training script that are important for understanding how to modify it, but it doesn't cover every aspect of the script in detail. If you're interested in learning more, feel free to read through the [script](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sd_wds.py) and let us know if you have any questions or concerns.
</Tip>
The training script provides many parameters to help you customize your training run. All of the parameters and their descriptions are found in the [`parse_args()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L419) function. This function provides default values for each parameter, such as the training batch size and learning rate, but you can also set your own values in the training command if you'd like.
For example, to speedup training with mixed precision using the fp16 format, add the `--mixed_precision` parameter to the training command:
```bash
accelerate launch train_lcm_distill_sd_wds.py \
--mixed_precision="fp16"
```
Most of the parameters are identical to the parameters in the [Text-to-image](text2image#script-parameters) training guide, so you'll focus on the parameters that are relevant to latent consistency distillation in this guide.
- `--pretrained_teacher_model`: the path to a pretrained latent diffusion model to use as the teacher model
- `--pretrained_vae_model_name_or_path`: path to a pretrained VAE; the SDXL VAE is known to suffer from numerical instability, so this parameter allows you to specify an alternative VAE (like this [VAE]((https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)) by madebyollin which works in fp16)
- `--w_min` and `--w_max`: the minimum and maximum guidance scale values for guidance scale sampling
- `--num_ddim_timesteps`: the number of timesteps for DDIM sampling
- `--loss_type`: the type of loss (L2 or Huber) to calculate for latent consistency distillation; Huber loss is generally preferred because it's more robust to outliers
- `--huber_c`: the Huber loss parameter
## Training script
The training script starts by creating a dataset class - [`Text2ImageDataset`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L141) - for preprocessing the images and creating a training dataset.
```py
def transform(example):
image = example["image"]
image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR)
c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution))
image = TF.crop(image, c_top, c_left, resolution, resolution)
image = TF.to_tensor(image)
image = TF.normalize(image, [0.5], [0.5])
example["image"] = image
return example
```
For improved performance on reading and writing large datasets stored in the cloud, this script uses the [WebDataset](https://github.com/webdataset/webdataset) format to create a preprocessing pipeline to apply transforms and create a dataset and dataloader for training. Images are processed and fed to the training loop without having to download the full dataset first.
```py
processing_pipeline = [
wds.decode("pil", handler=wds.ignore_and_continue),
wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue),
wds.map(filter_keys({"image", "text"})),
wds.map(transform),
wds.to_tuple("image", "text"),
]
```
In the [`main()`](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L768) function, all the necessary components like the noise scheduler, tokenizers, text encoders, and VAE are loaded. The teacher UNet is also loaded here and then you can create a student UNet from the teacher UNet. The student UNet is updated by the optimizer during training.
```py
teacher_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision
)
unet = UNet2DConditionModel(**teacher_unet.config)
unet.load_state_dict(teacher_unet.state_dict(), strict=False)
unet.train()
```
Now you can create the [optimizer](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L979) to update the UNet parameters:
```py
optimizer = optimizer_class(
unet.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
```
Create the [dataset](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L994):
```py
dataset = Text2ImageDataset(
train_shards_path_or_url=args.train_shards_path_or_url,
num_train_examples=args.max_train_samples,
per_gpu_batch_size=args.train_batch_size,
global_batch_size=args.train_batch_size * accelerator.num_processes,
num_workers=args.dataloader_num_workers,
resolution=args.resolution,
shuffle_buffer_size=1000,
pin_memory=True,
persistent_workers=True,
)
train_dataloader = dataset.train_dataloader
```
Next, you're ready to setup the [training loop](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1049) and implement the latent consistency distillation method (see Algorithm 1 in the paper for more details). This section of the script takes care of adding noise to the latents, sampling and creating a guidance scale embedding, and predicting the original image from the noise.
```py
pred_x_0 = predicted_origin(
noise_pred,
start_timesteps,
noisy_model_input,
noise_scheduler.config.prediction_type,
alpha_schedule,
sigma_schedule,
)
model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
```
It gets the [teacher model predictions](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1172) and the [LCM predictions](https://github.com/huggingface/diffusers/blob/3b37488fa3280aed6a95de044d7a42ffdcb565ef/examples/consistency_distillation/train_lcm_distill_sd_wds.py#L1209) next, calculates the loss, and then backpropagates it to the LCM.
```py
if args.loss_type == "l2":
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
elif args.loss_type == "huber":
loss = torch.mean(
torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
)
```
If you want to learn more about how the training loop works, check out the [Understanding pipelines, models and schedulers tutorial](../using-diffusers/write_own_pipeline) which breaks down the basic pattern of the denoising process.
## Launch the script
Now you're ready to launch the training script and start distilling!
For this guide, you'll use the `--train_shards_path_or_url` to specify the path to the [Conceptual Captions 12M](https://github.com/google-research-datasets/conceptual-12m) dataset stored on the Hub [here](https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset). Set the `MODEL_DIR` environment variable to the name of the teacher model and `OUTPUT_DIR` to where you want to save the model.
```bash
export MODEL_DIR="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="path/to/saved/model"
accelerate launch train_lcm_distill_sd_wds.py \
--pretrained_teacher_model=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision=fp16 \
--resolution=512 \
--learning_rate=1e-6 --loss_type="huber" --ema_decay=0.95 --adam_weight_decay=0.0 \
--max_train_steps=1000 \
--max_train_samples=4000000 \
--dataloader_num_workers=8 \
--train_shards_path_or_url="pipe:curl -L -s https://huggingface.co/datasets/laion/conceptual-captions-12m-webdataset/resolve/main/data/{00000..01099}.tar?download=true" \
--validation_steps=200 \
--checkpointing_steps=200 --checkpoints_total_limit=10 \
--train_batch_size=12 \
--gradient_checkpointing --enable_xformers_memory_efficient_attention \
--gradient_accumulation_steps=1 \
--use_8bit_adam \
--resume_from_checkpoint=latest \
--report_to=wandb \
--seed=453645634 \
--push_to_hub
```
Once training is complete, you can use your new LCM for inference.
```py
from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler
import torch
unet = UNet2DConditionModel.from_pretrained("your-username/your-model", torch_dtype=torch.float16, variant="fp16")
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16, variant="fp16")
pipeline.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipeline.to("cuda")
prompt = "sushi rolls in the form of panda heads, sushi platter"
image = pipeline(prompt, num_inference_steps=4, guidance_scale=1.0).images[0]
```
## LoRA
LoRA is a training technique for significantly reducing the number of trainable parameters. As a result, training is faster and it is easier to store the resulting weights because they are a lot smaller (~100MBs). Use the [train_lcm_distill_lora_sd_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py) or [train_lcm_distill_lora_sdxl.wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py) script to train with LoRA.
The LoRA training script is discussed in more detail in the [LoRA training](lora) guide.
## Stable Diffusion XL
Stable Diffusion XL (SDXL) is a powerful text-to-image model that generates high-resolution images, and it adds a second text-encoder to its architecture. Use the [train_lcm_distill_sdxl_wds.py](https://github.com/huggingface/diffusers/blob/main/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py) script to train a SDXL model with LoRA.
The SDXL training script is discussed in more detail in the [SDXL training](sdxl) guide.
## Next steps
Congratulations on distilling a LCM model! To learn more about LCM, the following may be helpful:
- Learn how to use [LCMs for inference](../using-diffusers/lcm) for text-to-image, image-to-image, and with LoRA checkpoints.
- Read the [SDXL in 4 steps with Latent Consistency LoRAs](https://huggingface.co/blog/lcm_lora) blog post to learn more about SDXL LCM-LoRA's for super fast inference, quality comparisons, benchmarks, and more.
@@ -20,6 +20,8 @@ The Kandinsky models are a series of multilingual text-to-image generation model
[Kandinsky 2.2](../api/pipelines/kandinsky_v22) improves on the previous model by replacing the image encoder of the image prior model with a larger CLIP-ViT-G model to improve quality. The image prior model was also retrained on images with different resolutions and aspect ratios to generate higher-resolution images and different image sizes.
[Kandinsky 3](../api/pipelines/kandinsky3) simplifies the architecture and shifts away from the two-stage generation process involving the prior model and diffusion model. Instead, Kandinsky 3 uses [Flan-UL2](https://huggingface.co/google/flan-ul2) to encode text, a UNet with [BigGan-deep](https://hf.co/papers/1809.11096) blocks, and [Sber-MoVQGAN](https://github.com/ai-forever/MoVQGAN) to decode the latents into images. Text understanding and generated image quality are primarily achieved by using a larger text encoder and UNet.
This guide will show you how to use the Kandinsky models for text-to-image, image-to-image, inpainting, interpolation, and more.
Before you begin, make sure you have the following libraries installed:
@@ -33,6 +35,10 @@ Before you begin, make sure you have the following libraries installed:
Kandinsky 2.1 and 2.2 usage is very similar! The only difference is Kandinsky 2.2 doesn't accept `prompt` as an input when decoding the latents. Instead, Kandinsky 2.2 only accepts `image_embeds` during decoding.
<br>
Kandinsky 3 has a more concise architecture and it doesn't require a prior model. This means it's usage is identical to other diffusion models like [Stable Diffusion XL](sdxl).
</Tip>
## Text-to-image
@@ -91,6 +97,23 @@ image
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-text-to-image.png"/>
</div>
</hfoption>
<hfoption id="Kandinsky 3">
Kandinsky 3 doesn't require a prior model so you can directly load the [`Kandinsky3Pipeline`] and pass a prompt to generate an image:
```py
from diffusers import Kandinsky3Pipeline
import torch
pipeline = Kandinsky3Pipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
image = pipeline(prompt).images[0]
image
```
</hfoption>
</hfoptions>
@@ -161,6 +184,20 @@ prior_pipeline = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kan
pipeline = KandinskyV22Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, use_safetensors=True).to("cuda")
```
</hfoption>
<hfoption id="Kandinsky 3">
Kandinsky 3 doesn't require a prior model so you can directly load the image-to-image pipeline:
```py
from diffusers import Kandinsky3Img2ImgPipeline
from diffusers.utils import load_image
import torch
pipeline = Kandinsky3Img2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
pipeline.enable_model_cpu_offload()
```
</hfoption>
</hfoptions>
@@ -218,6 +255,14 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-image-to-image.png"/>
</div>
</hfoption>
<hfoption id="Kandinsky 3">
```py
image = pipeline(prompt, negative_prompt=negative_prompt, image=image, strength=0.75, num_inference_steps=25).images[0]
image
```
</hfoption>
</hfoptions>
@@ -307,3 +307,331 @@ prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, b
image = pipeline(prompt=prompt).images[0]
image
```
## IP-Adapter
[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. This adapter works by decoupling the cross-attention layers of the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MBs.
IP-Adapter works with most of our pipelines, including Stable Diffusion, Stable Diffusion XL (SDXL), ControlNet, T2I-Adapter, AnimateDiff. And you can use any custom models finetuned from the same base models. It also works with LCM-Lora out of box.
<Tip>
You can find official IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).
IP-Adapter was contributed by [okotaku](https://github.com/okotaku).
</Tip>
Let's first create a Stable Diffusion Pipeline.
```py
from diffusers import AutoPipelineForText2Image
import torch
from diffusers.utils import load_image
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
```
Now load the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method.
```py
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
```
<Tip>
IP-Adapter relies on an image encoder to generate the image features, if your IP-Adapter weights folder contains a "image_encoder" subfolder, the image encoder will be automatically loaded and registered to the pipeline. Otherwise you can so load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to a Stable Diffusion pipeline when you create it.
```py
from diffusers import AutoPipelineForText2Image, CLIPVisionModelWithProjection
import torch
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter",
subfolder="models/image_encoder",
torch_dtype=torch.float16,
).to("cuda")
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")
```
</Tip>
IP-Adapter allows you to use both image and text to condition the image generation process. For example, let's use the bear image from the [Textual Inversion](#textual-inversion) section as the image prompt (`ip_adapter_image`) along with a text prompt to add "sunglasses". 😎
```py
pipeline.set_ip_adapter_scale(0.6)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality, wearing sunglasses',
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images
images[0]
```
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip-bear.png" />
</div>
<Tip>
You can use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to adjust the text prompt and image prompt condition ratio.  If you're only using the image prompt, you should set the scale to `1.0`. You can lower the scale to get more generation diversity, but it'll be less aligned with the prompt.
`scale=0.5` can achieve good results in most cases when you use both text and image prompts.
</Tip>
IP-Adapter also works great with Image-to-Image and Inpainting pipelines. See below examples of how you can use it with Image-to-Image and Inpaint.
<hfoptions id="tasks">
<hfoption id="image-to-image">
```py
from diffusers import AutoPipelineForImage2Image
import torch
from diffusers.utils import load_image
pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality',
    image = image,
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=generator,
    strength=0.6,
).images
images[0]
```
</hfoption>
<hfoption id="inpaint">
```py
from diffusers import AutoPipelineForInpaint
import torch
from diffusers.utils import load_image
pipeline = AutoPipelineForInpaint.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float).to("cuda")
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
image = image.resize((512, 768))
mask = mask.resize((512, 768))
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
prompt='best quality, high quality',
image = image,
mask_image = mask,
ip_adapter_image=ip_image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=50,
generator=generator,
strength=0.5,
).images
images[0]
```
</hfoption>
</hfoptions>
IP-Adapters can also be used with [SDXL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
```python
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
prompt="best quality, high quality",
ip_adapter_image=image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=25,
generator=generator,
).images[0]
image.save("sdxl_t2i.png")
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/sdxl_t2i.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
</div>
</div>
### LCM-Lora
You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights.
```py
from diffusers import DiffusionPipeline, LCMScheduler
import torch
from diffusers.utils import load_image
model_id = "sd-dreambooth-library/herge-style"
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.load_lora_weights(lcm_lora_id)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
prompt = "best quality, high quality"
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
images = pipe(
prompt=prompt,
ip_adapter_image=image,
num_inference_steps=4,
guidance_scale=1,
).images[0]
```
### Other pipelines
IP-Adapter is compatible with any pipeline that (1) uses a text prompt and (2) uses Stable Diffusion or Stable Diffusion XL checkpoint. To use IP-Adapter with a different pipeline, all you need to do is to run `load_ip_adapter()` method after you create the pipeline, and then pass your image to the pipeline as `ip_adapter_image`
<Tip>
🤗 Diffusers currently only supports using IP-Adapter with some of the most popular pipelines, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require integrating IP-adapters with a pipeline that does not support it yet!
</Tip>
You can find below examples on how to use IP-Adapter with ControlNet and AnimateDiff.
<hfoptions id="model">
<hfoption id="ControlNet">
```
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch
from diffusers.utils import load_image
controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth"
controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16)
pipeline.to("cuda")
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png")
depth_map = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
prompt='best quality, high quality',
image=depth_map,
ip_adapter_image=image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=50,
generator=generator,
).images
images[0]
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
</div>
</div>
</hfoption>
<hfoption id="AnimateDiff">
```py
# animate diff + ip adapter
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers.utils import export_to_gif, load_image
# Load the motion adapter
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
# load SD 1.5 based finetuned model
model_id = "Lykon/DreamShaper"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
# scheduler
scheduler = DDIMScheduler(
clip_sample=False,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
timestep_spacing="trailing",
steps_offset=1
)
pipe.scheduler = scheduler
# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_model_cpu_offload()
# load ip_adapter
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
# load motion adapters
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe.load_lora_weights("guoyww/animatediff-motion-lora-tilt-up", adapter_name="tilt-up")
pipe.load_lora_weights("guoyww/animatediff-motion-lora-pan-left", adapter_name="pan-left")
seed = 42
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
images = [image] * 3
prompts = ["best quality, high quality"] * 3
negative_prompt = "bad quality, worst quality"
adapter_weights = [[0.75, 0.0, 0.0], [0.0, 0.0, 0.75], [0.0, 0.75, 0.75]]
# generate
output_frames = []
for prompt, image, adapter_weight in zip(prompts, images, adapter_weights):
pipe.set_adapters(["zoom-out", "tilt-up", "pan-left"], adapter_weights=adapter_weight)
output = pipe(
prompt= prompt,
num_frames=16,
guidance_scale=7.5,
num_inference_steps=30,
ip_adapter_image = image,
generator=torch.Generator("cpu").manual_seed(seed),
)
frames = output.frames[0]
output_frames.extend(frames)
export_to_gif(output_frames, "test_out_animation.gif")
```
</hfoption>
</hfoptions>
@@ -0,0 +1,116 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Diffusion XL Turbo
[[open-in-colab]]
SDXL Turbo is an adversarial time-distilled [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (SDXL) model capable
of running inference in as little as 1 step.
This guide will show you how to use SDXL-Turbo for text-to-image and image-to-image.
Before you begin, make sure you have the following libraries installed:
```py
# uncomment to install the necessary libraries in Colab
#!pip install -q diffusers transformers accelerate omegaconf
```
## Load model checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~StableDiffusionXLPipeline.from_pretrained`] method:
```py
from diffusers import AutoPipelineForText2Image, AutoPipelineForImage2Image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipeline = pipeline.to("cuda")
```
You can also use the [`~StableDiffusionXLPipeline.from_single_file`] method to load a model checkpoint stored in a single file format (`.ckpt` or `.safetensors`) from the Hub or locally:
```py
from diffusers import StableDiffusionXLPipeline
import torch
pipeline = StableDiffusionXLPipeline.from_single_file(
"https://huggingface.co/stabilityai/sdxl-turbo/blob/main/sd_xl_turbo_1.0_fp16.safetensors", torch_dtype=torch.float16)
pipeline = pipeline.to("cuda")
```
## Text-to-image
For text-to-image, pass a text prompt. By default, SDXL Turbo generates a 512x512 image, and that resolution gives the best results. You can try setting the `height` and `width` parameters to 768x768 or 1024x1024, but you should expect quality degradations when doing so.
Make sure to set `guidance_scale` to 0.0 to disable, as the model was trained without it. A single inference step is enough to generate high quality images.
Increasing the number of steps to 2, 3 or 4 should improve image quality.
```py
from diffusers import AutoPipelineForText2Image
import torch
pipeline_text2image = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipeline_text2image = pipeline_text2image.to("cuda")
prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
image = pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=1).images[0]
image
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sdxl-turbo-text2img.png" alt="generated image of a racoon in a robe"/>
</div>
## Image-to-image
For image-to-image generation, make sure that `num_inference_steps * strength` is larger or equal to 1.
The image-to-image pipeline will run for `int(num_inference_steps * strength)` steps, e.g. `0.5 * 2.0 = 1` step in
our example below.
```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image, make_image_grid
# use from_pipe to avoid consuming additional memory when loading a checkpoint
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_text2image).to("cuda")
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
init_image = init_image.resize((512, 512))
prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
image = pipeline(prompt, image=init_image, strength=0.5, guidance_scale=0.0, num_inference_steps=2).images[0]
make_image_grid([init_image, image], rows=1, cols=2)
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/sdxl-turbo-img2img.png" alt="Image-to-image generation sample using SDXL Turbo"/>
</div>
## Speed-up SDXL Turbo even more
- Compile the UNet if you are using PyTorch version 2 or better. The first inference run will be very slow, but subsequent ones will be much faster.
```py
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
- When using the default VAE, keep it in `float32` to avoid costly `dtype` conversions before and after each generation. You only need to do this one before your first generation:
```py
pipe.upcast_vae()
```
As an alternative, you can also use a [16-bit VAE](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) created by community member [`@madebyollin`](https://huggingface.co/madebyollin) that does not need to be upcasted to `float32`.
+134
View File
@@ -0,0 +1,134 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Video Diffusion
[[open-in-colab]]
[Stable Video Diffusion](https://static1.squarespace.com/static/6213c340453c3f502425776e/t/655ce779b9d47d342a93c890/1700587395994/stable_video_diffusion.pdf) is a powerful image-to-video generation model that can generate high resolution (576x1024) 2-4 second videos conditioned on the input image.
This guide will show you how to use SVD to short generate videos from images.
Before you begin, make sure you have the following libraries installed:
```py
!pip install -q -U diffusers transformers accelerate
```
## Image to Video Generation
The are two variants of SVD. [SVD](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid)
and [SVD-XT](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt). The svd checkpoint is trained to generate 14 frames and the svd-xt checkpoint is further
finetuned to generate 25 frames.
We will use the `svd-xt` checkpoint for this guide.
```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()
# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))
generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
<video controls width="1024" height="576">
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.webm" type="video/webm" />
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4" type="video/mp4" />
</video>
<Tip>
Since generating videos is more memory intensive we can use the `decode_chunk_size` argument to control how many frames are decoded at once. This will reduce the memory usage. It's recommended to tweak this value based on your GPU memory.
Setting `decode_chunk_size=1` will decode one frame at a time and will use the least amount of memory but the video might have some flickering.
Additionally, we also use [model cpu offloading](../../optimization/memory#model-offloading) to reduce the memory usage.
</Tip>
### Torch.compile
You can achieve a 20-25% speed-up at the expense of slightly increased memory by compiling the UNet as follows:
```diff
- pipe.enable_model_cpu_offload()
+ pipe.to("cuda")
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
### Low-memory
Video generation is very memory intensive as we have to essentially generate `num_frames` all at once. The mechanism is very comparable to text-to-image generation with a high batch size. To reduce the memory requirement you have multiple options. The following options trade inference speed against lower memory requirement:
- enable model offloading: Each component of the pipeline is offloaded to CPU once it's not needed anymore.
- enable feed-forward chunking: The feed-forward layer runs in a loop instead of running with a single huge feed-forward batch size
- reduce `decode_chunk_size`: This means that the VAE decodes frames in chunks instead of decoding them all together. **Note**: In addition to leading to a small slowdown, this method also slightly leads to video quality deterioration
You can enable them as follows:
```diff
-pipe.enable_model_cpu_offload()
-frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0]
+pipe.enable_model_cpu_offload()
+pipe.unet.enable_forward_chunking()
+frames = pipe(image, decode_chunk_size=2, generator=generator, num_frames=25).frames[0]
```
Including all these tricks should lower the memory requirement to less than 8GB VRAM.
### Micro-conditioning
Along with conditioning image Stable Diffusion Video also allows providing micro-conditioning that allows more control over the generated video.
It accepts the following arguments:
- `fps`: The frames per second of the generated video.
- `motion_bucket_id`: The motion bucket id to use for the generated video. This can be used to control the motion of the generated video. Increasing the motion bucket id will increase the motion of the generated video.
- `noise_aug_strength`: The amount of noise added to the conditioning image. The higher the values the less the video will resemble the conditioning image. Increasing this value will also increase the motion of the generated video.
Here is an example of using micro-conditioning to generate a video with more motion.
```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video
pipe = StableVideoDiffusionPipeline.from_pretrained(
"stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()
# Load the conditioning image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true")
image = image.resize((1024, 576))
generator = torch.manual_seed(42)
frames = pipe(image, decode_chunk_size=8, generator=generator, motion_bucket_id=180, noise_aug_strength=0.1).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```
<video width="1024" height="576" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated_motion.mp4" type="video/mp4">
</video>
@@ -14,54 +14,41 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
Unconditional image generation is a relatively straightforward task. The model only generates images - without any additional context like text or an image - resembling the training data it was trained on.
Unconditional image generation generates images that look like a random sample from the training data the model was trained on because the denoising process is not guided by any additional context like text or image.
The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference.
To get started, use the [`DiffusionPipeline`] to load the [anton-l/ddpm-butterflies-128](https://huggingface.co/anton-l/ddpm-butterflies-128) checkpoint to generate images of butterflies. The [`DiffusionPipeline`] downloads and caches all the model components required to generate an image.
Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline checkpoint you would like to download.
You can use any of the 🧨 Diffusers [checkpoints](https://huggingface.co/models?library=diffusers&sort=downloads) from the Hub (the checkpoint you'll use generates images of butterflies).
<Tip>
💡 Want to train your own unconditional image generation model? Take a look at the training [guide](../training/unconditional_training) to learn how to generate your own images.
</Tip>
In this guide, you'll use [`DiffusionPipeline`] for unconditional image generation with [DDPM](https://arxiv.org/abs/2006.11239):
```python
```py
from diffusers import DiffusionPipeline
generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128", use_safetensors=True)
```
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on a GPU.
You can move the generator object to a GPU, just like you would in PyTorch:
```python
generator.to("cuda")
```
Now you can use the `generator` to generate an image:
```python
generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128").to("cuda")
image = generator().images[0]
image
```
The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
<Tip>
You can save the image by calling:
Want to generate images of something else? Take a look at the training [guide](../training/unconditional_training) to learn how to train a model to generate your own images.
```python
</Tip>
The output image is a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object that can be saved:
```py
image.save("generated_image.png")
```
Try out the Spaces below, and feel free to play around with the inference steps parameter to see how it affects the image quality!
You can also try experimenting with the `num_inference_steps` parameter, which controls the number of denoising steps. More denoising steps typically produce higher quality images, but it'll take longer to generate. Feel free to play around with this parameter to see how it affects the image quality.
```py
image = generator(num_inference_steps=100).images[0]
image
```
Try out the Space below to generate an image of a butterfly!
<iframe
src="https://stevhliu-ddpm-butterflies-128.hf.space"
src="https://stevhliu-unconditional-image-generation.hf.space"
frameborder="0"
width="850"
height="500"
+1
View File
@@ -96,3 +96,4 @@ specific language governing permissions and limitations under the License.
| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
| [stable_diffusion_ldm3d](./api/pipelines/stable_diffusion/ldm3d_diffusion) | [LDM3D: Latent Diffusion Model for 3D](https://arxiv.org/abs/2305.10853) | Text to Image and Depth Generation |
| [stable_diffusion_upscaler_ldm3d](./api/pipelines/stable_diffusion/ldm3d_diffusion) | [LDM3D-VR: Latent Diffusion Model for 3D VR](https://arxiv.org/pdf/2311.03226) | Image and Depth Upscaling |
File diff suppressed because it is too large Load Diff
+579 -1
View File
@@ -48,7 +48,9 @@ prompt-to-prompt | change parts of a prompt and retain image structure (see [pap
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#DemoFusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -76,6 +78,7 @@ from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
"longlian/lmd_plus",
custom_pipeline="llm_grounded_diffusion",
custom_revision="main",
variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
@@ -2343,3 +2346,578 @@ images = pipe(
assert len(images) == (len(prompts) - 1) * num_interpolation_steps
```
### StableDiffusionUpscaleLDM3D Pipeline
[LDM3D-VR](https://arxiv.org/pdf/2311.03226.pdf) is an extended version of LDM3D.
The abstract from the paper is:
*Latent diffusion models have proven to be state-of-the-art in the creation and manipulation of visual outputs. However, as far as we know, the generation of depth maps jointly with RGB is still limited. We introduce LDM3D-VR, a suite of diffusion models targeting virtual reality development that includes LDM3D-pano and LDM3D-SR. These models enable the generation of panoramic RGBD based on textual prompts and the upscaling of low-resolution inputs to high-resolution RGBD, respectively. Our models are fine-tuned from existing pretrained models on datasets containing panoramic/high-resolution RGB images, depth maps and captions. Both models are evaluated in comparison to existing related methods*
Two checkpoints are available for use:
- [ldm3d-pano](https://huggingface.co/Intel/ldm3d-pano). This checkpoint enables the generation of panoramic images and requires the StableDiffusionLDM3DPipeline pipeline to be used.
- [ldm3d-sr](https://huggingface.co/Intel/ldm3d-sr). This checkpoint enables the upscaling of RGB and depth images. Can be used in cascade after the original LDM3D pipeline using the StableDiffusionUpscaleLDM3DPipeline pipeline.
'''py
from PIL import Image
import os
import torch
from diffusers import StableDiffusionLDM3DPipeline, DiffusionPipeline
#Generate a rgb/depth output from LDM3D
pipe_ldm3d = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c")
pipe_ldm3d.to("cuda")
prompt =f"A picture of some lemons on a table"
output = pipe_ldm3d(prompt)
rgb_image, depth_image = output.rgb, output.depth
rgb_image[0].save(f"lemons_ldm3d_rgb.jpg")
depth_image[0].save(f"lemons_ldm3d_depth.png")
#Upscale the previous output to a resolution of (1024, 1024)
pipe_ldm3d_upscale = DiffusionPipeline.from_pretrained("Intel/ldm3d-sr", custom_pipeline="pipeline_stable_diffusion_upscale_ldm3d")
pipe_ldm3d_upscale.to("cuda")
low_res_img = Image.open(f"lemons_ldm3d_rgb.jpg").convert("RGB")
low_res_depth = Image.open(f"lemons_ldm3d_depth.png").convert("L")
outputs = pipe_ldm3d_upscale(prompt="high quality high resolution uhd 4k image", rgb=low_res_img, depth=low_res_depth, num_inference_steps=50, target_res=[1024, 1024])
upscaled_rgb, upscaled_depth =outputs.rgb[0], outputs.depth[0]
upscaled_rgb.save(f"upscaled_lemons_rgb.png")
upscaled_depth.save(f"upscaled_lemons_depth.png")
'''
### ControlNet + T2I Adapter Pipeline
This pipelines combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once.
It receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale = 0` or `controlnet_conditioning_scale = 0`, it will act as a full ControlNet module or as a full T2IAdapter module, respectively.
```py
import cv2
import numpy as np
import torch
from controlnet_aux.midas import MidasDetector
from PIL import Image
from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.utils import load_image
from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import (
StableDiffusionXLControlNetAdapterPipeline,
)
controlnet_depth = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
adapter_depth = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnet_depth,
adapter=adapter_depth,
vae=vae,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to("cuda")
prompt = "a tiger sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
image = load_image(img_url).resize((1024, 1024))
depth_image = midas_depth(
image, detect_resolution=512, image_resolution=1024
)
strength = 0.5
images = pipe(
prompt,
control_image=depth_image,
adapter_image=depth_image,
num_inference_steps=30,
controlnet_conditioning_scale=strength,
adapter_conditioning_scale=strength,
).images
images[0].save("controlnet_and_adapter.png")
```
### ControlNet + T2I Adapter + Inpainting Pipeline
```py
import cv2
import numpy as np
import torch
from controlnet_aux.midas import MidasDetector
from PIL import Image
from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.utils import load_image
from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import (
StableDiffusionXLControlNetAdapterInpaintPipeline,
)
controlnet_depth = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
adapter_depth = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
controlnet=controlnet_depth,
adapter=adapter_depth,
vae=vae,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to("cuda")
prompt = "a tiger sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
image = load_image(img_url).resize((1024, 1024))
mask_image = load_image(mask_url).resize((1024, 1024))
depth_image = midas_depth(
image, detect_resolution=512, image_resolution=1024
)
strength = 0.4
images = pipe(
prompt,
image=image,
mask_image=mask_image,
control_image=depth_image,
adapter_image=depth_image,
num_inference_steps=30,
controlnet_conditioning_scale=strength,
adapter_conditioning_scale=strength,
strength=0.7,
).images
images[0].save("controlnet_and_adapter_inpaint.png")
```
### Regional Prompting Pipeline
This pipeline is a port of the [Regional Prompter extension](https://github.com/hako-mikan/sd-webui-regional-prompter) for [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) to diffusers.
This code implements a pipeline for the Stable Diffusion model, enabling the division of the canvas into multiple regions, with different prompts applicable to each region. Users can specify regions in two ways: using `Cols` and `Rows` modes for grid-like divisions, or the `Prompt` mode for regions calculated based on prompts.
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline1.png)
### Usage
### Sample Code
```
from from examples.community.regional_prompting_stable_diffusion import RegionalPromptingStableDiffusionPipeline
pipe = RegionalPromptingStableDiffusionPipeline.from_single_file(model_path, vae=vae)
rp_args = {
"mode":"rows",
"div": "1;1;1"
}
prompt ="""
green hair twintail BREAK
red blouse BREAK
blue skirt
"""
images = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
guidance_scale=7.5,
height = 768,
width = 512,
num_inference_steps =20,
num_images_per_prompt = 1,
rp_args = rp_args
).images
time = time.strftime(r"%Y%m%d%H%M%S")
i = 1
for image in images:
i += 1
fileName = f'img-{time}-{i+1}.png'
image.save(fileName)
```
### Cols, Rows mode
In the Cols, Rows mode, you can split the screen vertically and horizontally and assign prompts to each region. The split ratio can be specified by 'div', and you can set the division ratio like '3;3;2' or '0.1;0.5'. Furthermore, as will be described later, you can also subdivide the split Cols, Rows to specify more complex regions.
In this image, the image is divided into three parts, and a separate prompt is applied to each. The prompts are divided by 'BREAK', and each is applied to the respective region.
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline2.png)
```
green hair twintail BREAK
red blouse BREAK
blue skirt
```
### 2-Dimentional division
The prompt consists of instructions separated by the term `BREAK` and is assigned to different regions of a two-dimensional space. The image is initially split in the main splitting direction, which in this case is rows, due to the presence of a single semicolon`;`, dividing the space into an upper and a lower section. Additional sub-splitting is then applied, indicated by commas. The upper row is split into ratios of `2:1:1`, while the lower row is split into a ratio of `4:6`. Rows themselves are split in a `1:2` ratio. According to the reference image, the blue sky is designated as the first region, green hair as the second, the bookshelf as the third, and so on, in a sequence based on their position from the top left. The terrarium is placed on the desk in the fourth region, and the orange dress and sofa are in the fifth region, conforming to their respective splits.
```
rp_args = {
"mode":"rows",
"div": "1,2,1,1;2,4,6"
}
prompt ="""
blue sky BREAK
green hair BREAK
book shelf BREAK
terrarium on desk BREAK
orange dress and sofa
"""
```
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline4.png)
### Prompt Mode
There are limitations to methods of specifying regions in advance. This is because specifying regions can be a hindrance when designating complex shapes or dynamic compositions. In the region specified by the prompt, the regions is determined after the image generation has begun. This allows us to accommodate compositions and complex regions.
For further infomagen, see [here](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/main/prompt_en.md).
### syntax
```
baseprompt target1 target2 BREAK
effect1, target1 BREAK
effect2 ,target2
```
First, write the base prompt. In the base prompt, write the words (target1, target2) for which you want to create a mask. Next, separate them with BREAK. Next, write the prompt corresponding to target1. Then enter a comma and write target1. The order of the targets in the base prompt and the order of the BREAK-separated targets can be back to back.
```
target2 baseprompt target1 BREAK
effect1, target1 BREAK
effect2 ,target2
```
is also effective.
### Sample
In this example, masks are calculated for shirt, tie, skirt, and color prompts are specified only for those regions.
```
rp_args = {
"mode":"prompt-ex",
"save_mask":True,
"th": "0.4,0.6,0.6",
}
prompt ="""
a girl in street with shirt, tie, skirt BREAK
red, shirt BREAK
green, tie BREAK
blue , skirt
"""
```
![sample](https://github.com/hako-mikan/sd-webui-regional-prompter/blob/imgs/rp_pipeline3.png)
### threshold
The threshold used to determine the mask created by the prompt. This can be set as many times as there are masks, as the range varies widely depending on the target prompt. If multiple regions are used, enter them separated by commas. For example, hair tends to be ambiguous and requires a small value, while face tends to be large and requires a small value. These should be ordered by BREAK.
```
a lady ,hair, face BREAK
red, hair BREAK
tanned ,face
```
`threshold : 0.4,0.6`
If only one input is given for multiple regions, they are all assumed to be the same value.
### Prompt and Prompt-EX
The difference is that in Prompt, duplicate regions are added, whereas in Prompt-EX, duplicate regions are overwritten sequentially. Since they are processed in order, setting a TARGET with a large regions first makes it easier for the effect of small regions to remain unmuffled.
### Accuracy
In the case of a 512 x 512 image, Attention mode reduces the size of the region to about 8 x 8 pixels deep in the U-Net, so that small regions get mixed up; Latent mode calculates 64*64, so that the region is exact.
```
girl hair twintail frills,ribbons, dress, face BREAK
girl, ,face
```
### Mask
When an image is generated, the generated mask is displayed. It is generated at the same size as the image, but is actually used at a much smaller size.
### Use common prompt
You can attach the prompt up to ADDCOMM to all prompts by separating it first with ADDCOMM. This is useful when you want to include elements common to all regions. For example, when generating pictures of three people with different appearances, it's necessary to include the instruction of 'three people' in all regions. It's also useful when inserting quality tags and other things."For example, if you write as follows:
```
best quality, 3persons in garden, ADDCOMM
a girl white dress BREAK
a boy blue shirt BREAK
an old man red suit
```
If common is enabled, this prompt is converted to the following:
```
best quality, 3persons in garden, a girl white dress BREAK
best quality, 3persons in garden, a boy blue shirt BREAK
best quality, 3persons in garden, an old man red suit
```
### Negative prompt
Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
### Parameters
To activate Regional Prompter, it is necessary to enter settings in rp_args. The items that can be set are as follows. rp_args is a dictionary type.
### Input Parameters
Parameters are specified through the `rp_arg`(dictionary type).
```
rp_args = {
"mode":"rows",
"div": "1;1;1"
}
pipe(prompt =prompt, rp_args = rp_args)
```
### Required Parameters
- `mode`: Specifies the method for defining regions. Choose from `Cols`, `Rows`, `Prompt` or `Prompt-Ex`. This parameter is case-insensitive.
- `divide`: Used in `Cols` and `Rows` modes. Details on how to specify this are provided under the respective `Cols` and `Rows` sections.
- `th`: Used in `Prompt` mode. The method of specification is detailed under the `Prompt` section.
### Optional Parameters
- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
## Diffusion Posterior Sampling Pipeline
* Reference paper
```
@article{chung2022diffusion,
title={Diffusion posterior sampling for general noisy inverse problems},
author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},
journal={arXiv preprint arXiv:2209.14687},
year={2022}
}
```
* This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given observation on $y$, unconditional generative model $p(x)$ and differentiable operator $y=f(x)$.
* For example, $f(.)$ can be downsample operator, then $y$ is a downsampled image, and the pipeline becomes a super-resolution pipeline.
* To use this pipeline, you need to know your operator $f(.)$ and corrupted image $y$, and pass them during the call. For example, as in the main function of dps_pipeline.py, you need to first define the Gaussian blurring operator $f(.)$. The operator should be a callable nn.Module, with all the parameter gradient disabled:
```python
import torch.nn.functional as F
import scipy
from torch import nn
# define the Gaussian blurring operator first
class GaussialBlurOperator(nn.Module):
def __init__(self, kernel_size, intensity):
super().__init__()
class Blurkernel(nn.Module):
def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):
super().__init__()
self.blur_type = blur_type
self.kernel_size = kernel_size
self.std = std
self.seq = nn.Sequential(
nn.ReflectionPad2d(self.kernel_size//2),
nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
)
self.weights_init()
def forward(self, x):
return self.seq(x)
def weights_init(self):
if self.blur_type == "gaussian":
n = np.zeros((self.kernel_size, self.kernel_size))
n[self.kernel_size // 2,self.kernel_size // 2] = 1
k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
k = torch.from_numpy(k)
self.k = k
for name, f in self.named_parameters():
f.data.copy_(k)
elif self.blur_type == "motion":
k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
k = torch.from_numpy(k)
self.k = k
for name, f in self.named_parameters():
f.data.copy_(k)
def update_weights(self, k):
if not torch.is_tensor(k):
k = torch.from_numpy(k)
for name, f in self.named_parameters():
f.data.copy_(k)
def get_kernel(self):
return self.k
self.kernel_size = kernel_size
self.conv = Blurkernel(blur_type='gaussian',
kernel_size=kernel_size,
std=intensity)
self.kernel = self.conv.get_kernel()
self.conv.update_weights(self.kernel.type(torch.float32))
for param in self.parameters():
param.requires_grad=False
def forward(self, data, **kwargs):
return self.conv(data)
def transpose(self, data, **kwargs):
return data
def get_kernel(self):
return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
```
* Next, you should obtain the corrupted image $y$ by the operator. In this example, we generate $y$ from the source image $x$. However in practice, having the operator $f(.)$ and corrupted image $y$ is enough:
```python
# set up source image
src = Image.open('sample.png')
# read image into [1,3,H,W]
src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2,0,1)[None]
# normalize image to [-1,1]
src = (src / 127.5) - 1.0
src = src.to("cuda")
# set up operator and measurement
operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
measurement = operator(src)
# save the source and corrupted images
save_image((src+1.0)/2.0, "dps_src.png")
save_image((measurement+1.0)/2.0, "dps_mea.png")
```
* We provide an example pair of saved source and corrupted images, using the Gaussian blur operator above
* Source image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/4d2a1216-08d1-4aeb-9ce3-7a2d87561d65)
* Gaussian blurred image:
* ![ddpm_generated_image](https://github.com/tongdaxu/Images/assets/22267548/65076258-344b-4ed8-b704-a04edaade8ae)
* You can download those image to run the example on your own.
* Next, we need to define a loss function used for diffusion posterior sample. For most of the cases, the RMSE is fine:
```python
def RMSELoss(yhat, y):
return torch.sqrt(torch.sum((yhat-y)**2))
```
* And next, as any other diffusion models, we need the score estimator and scheduler. As we are working with $256x256$ face images, we use ddmp-celebahq-256:
```python
# set up scheduler
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(1000)
# set up model
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
```
* And finally, run the pipeline:
```python
# finally, the pipeline
dpspipe = DPSPipeline(model, scheduler)
image = dpspipe(
measurement = measurement,
operator = operator,
loss_fn = RMSELoss,
zeta = 1.0,
).images[0]
image.save("dps_generated_image.png")
```
* The zeta is a hyperparameter that is in range of $[0,1]$. It need to be tuned for best effect. By setting zeta=1, you should be able to have the reconstructed result:
* Reconstructed image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/0ceb5575-d42e-4f0b-99c0-50e69c982209)
* The reconstruction is perceptually similar to the source image, but different in details.
* In dps_pipeline.py, we also provide a super-resolution example, which should produce:
* Downsampled image:
* ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13)
* Reconstructed image:
* ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f)
### DemoFusion
This pipeline is the official implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973).
The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).
- `view_batch_size` (`int`, defaults to 16):
The batch size for multiple denoising paths. Typically, a larger batch size can result in higher efficiency but comes with increased GPU memory requirements.
- `stride` (`int`, defaults to 64):
The stride of moving local patches. A smaller stride is better for alleviating seam issues, but it also introduces additional computational overhead and inference time.
- `cosine_scale_1` (`float`, defaults to 3):
Control the strength of skip-residual. For specific impacts, please refer to Appendix C in the DemoFusion paper.
- `cosine_scale_2` (`float`, defaults to 1):
Control the strength of dilated sampling. For specific impacts, please refer to Appendix C in the DemoFusion paper.
- `cosine_scale_3` (`float`, defaults to 1):
Control the strength of the Gaussian filter. For specific impacts, please refer to Appendix C in the DemoFusion paper.
- `sigma` (`float`, defaults to 1):
The standard value of the Gaussian filter. Larger sigma promotes the global guidance of dilated sampling, but has the potential of over-smoothing.
- `multi_decoder` (`bool`, defaults to True):
Determine whether to use a tiled decoder. Generally, when the resolution exceeds 3072x3072, a tiled decoder becomes necessary.
- `show_image` (`bool`, defaults to False):
Determine whether to show intermediate results during generation.
```
from pipeline_demofusion_sdxl import DemoFusionSDXLPipeline
model_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DemoFusionSDXLPipeline.from_pretrained(model_ckpt, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "Envision a portrait of an elderly woman, her face a canvas of time, framed by a headscarf with muted tones of rust and cream. Her eyes, blue like faded denim. Her attire, simple yet dignified."
negative_prompt = "blurry, ugly, duplicate, poorly drawn, deformed, mosaic"
images = pipe(
prompt,
negative_prompt=negative_prompt,
height=3072,
width=3072,
view_batch_size=16,
stride=64,
num_inference_steps=50,
guidance_scale=7.5,
cosine_scale_1=3,
cosine_scale_2=1,
cosine_scale_3=1,
sigma=0.8,
multi_decoder=True,
show_image=True
)
```
You can display and save the generated images as:
```
def image_grid(imgs, save_path=None):
w = 0
for i, img in enumerate(imgs):
h_, w_ = imgs[i].size
w += w_
h = h_
grid = Image.new('RGB', size=(w, h))
grid_w, grid_h = grid.size
w = 0
for i, img in enumerate(imgs):
h_, w_ = imgs[i].size
grid.paste(img, box=(w, h - h_))
if save_path != None:
img.save(save_path + "/img_{}.jpg".format((i + 1) * 1024))
w += w_
return grid
image_grid(images, save_path="./outputs/")
```
![output_example](https://github.com/PRIS-CV/DemoFusion/blob/main/output_example.png)
+466
View File
@@ -0,0 +1,466 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import pi
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from diffusers import DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DModel
from diffusers.utils.torch_utils import randn_tensor
class DPSPipeline(DiffusionPipeline):
r"""
Pipeline for Diffusion Posterior Sampling.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Parameters:
unet ([`UNet2DModel`]):
A `UNet2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
model_cpu_offload_seq = "unet"
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
measurement: torch.Tensor,
operator: torch.nn.Module,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
batch_size: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
num_inference_steps: int = 1000,
output_type: Optional[str] = "pil",
return_dict: bool = True,
zeta: float = 0.3,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
The call function to the pipeline for generation.
Args:
measurement (`torch.Tensor`, *required*):
A 'torch.Tensor', the corrupted image
operator (`torch.nn.Module`, *required*):
A 'torch.nn.Module', the operator generating the corrupted image
loss_fn (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *required*):
A 'Callable[[torch.Tensor, torch.Tensor], torch.Tensor]', the loss function used
between the measurements, for most of the cases using RMSE is fine.
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
num_inference_steps (`int`, *optional*, defaults to 1000):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Example:
```py
>>> from diffusers import DDPMPipeline
>>> # load model and scheduler
>>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
>>> # run pipeline in inference (sample random noise and denoise)
>>> image = pipe().images[0]
>>> # save image
>>> image.save("ddpm_generated_image.png")
```
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
# Sample gaussian noise to begin loop
if isinstance(self.unet.config.sample_size, int):
image_shape = (
batch_size,
self.unet.config.in_channels,
self.unet.config.sample_size,
self.unet.config.sample_size,
)
else:
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = randn_tensor(image_shape, generator=generator)
image = image.to(self.device)
else:
image = randn_tensor(image_shape, generator=generator, device=self.device)
# set step values
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
with torch.enable_grad():
# 1. predict noise model_output
image = image.requires_grad_()
model_output = self.unet(image, t).sample
# 2. compute previous image x'_{t-1} and original prediction x0_{t}
scheduler_out = self.scheduler.step(model_output, t, image, generator=generator)
image_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample
# 3. compute y'_t = f(x0_{t})
measurement_pred = operator(origi_pred)
# 4. compute loss = d(y, y'_t-1)
loss = loss_fn(measurement, measurement_pred)
loss.backward()
print("distance: {0:.4f}".format(loss.item()))
with torch.no_grad():
image_pred = image_pred - zeta * image.grad
image = image_pred.detach()
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
if __name__ == "__main__":
import scipy
from torch import nn
from torchvision.utils import save_image
# defining the operators f(.) of y = f(x)
# super-resolution operator
class SuperResolutionOperator(nn.Module):
def __init__(self, in_shape, scale_factor):
super().__init__()
# Resizer local class, do not use outiside the SR operator class
class Resizer(nn.Module):
def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
super(Resizer, self).__init__()
# First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)
# Choose interpolation method, each method has the matching kernel size
def cubic(x):
absx = np.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + (
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * ((1 < absx) & (absx <= 2))
def lanczos2(x):
return (
(np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps)
/ ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)
) * (abs(x) < 2)
def box(x):
return ((-0.5 <= x) & (x < 0.5)) * 1.0
def lanczos3(x):
return (
(np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps)
/ ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)
) * (abs(x) < 3)
def linear(x):
return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
method, kernel_width = {
"cubic": (cubic, 4.0),
"lanczos2": (lanczos2, 4.0),
"lanczos3": (lanczos3, 6.0),
"box": (box, 1.0),
"linear": (linear, 2.0),
None: (cubic, 4.0), # set default interpolation method as cubic
}.get(kernel)
# Antialiasing is only used when downscaling
antialiasing *= np.any(np.array(scale_factor) < 1)
# Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient
sorted_dims = np.argsort(np.array(scale_factor))
self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]
# Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
field_of_view_list = []
weights_list = []
for dim in self.sorted_dims:
# for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
# weights that multiply the values there to get its result.
weights, field_of_view = self.contributions(
in_shape[dim], output_shape[dim], scale_factor[dim], method, kernel_width, antialiasing
)
# convert to torch tensor
weights = torch.tensor(weights.T, dtype=torch.float32)
# We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
# tmp_im[field_of_view.T], (bsxfun style)
weights_list.append(
nn.Parameter(
torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
requires_grad=False,
)
)
field_of_view_list.append(
nn.Parameter(
torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False
)
)
self.field_of_view = nn.ParameterList(field_of_view_list)
self.weights = nn.ParameterList(weights_list)
def forward(self, in_tensor):
x = in_tensor
# Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim
for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
# To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
x = torch.transpose(x, dim, 0)
# This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1.
# for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim
# only, this is why it only adds 1 dim to 5the shape). We then multiply, for each pixel, its set of positions with
# the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style:
# matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the
# same number
x = torch.sum(x[fov] * w, dim=0)
# Finally we swap back the axes to the original order
x = torch.transpose(x, dim, 0)
return x
def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
# First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the
# same size as the number of input dimensions)
if scale_factor is not None:
# By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
if np.isscalar(scale_factor) and len(input_shape) > 1:
scale_factor = [scale_factor, scale_factor]
# We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales
scale_factor = list(scale_factor)
scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor
# Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size
# to all the unspecified dimensions
if output_shape is not None:
output_shape = list(input_shape[len(output_shape) :]) + list(np.uint(np.array(output_shape)))
# Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is
# sub-optimal, because there can be different scales to the same output-shape.
if scale_factor is None:
scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
# Dealing with missing output-shape. calculating according to scale-factor
if output_shape is None:
output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
return scale_factor, output_shape
def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
# This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
# such that each position from the field_of_view will be multiplied with a matching filter from the
# 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
# around it. This is only done for one dimension of the image.
# When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of
# 1/sf. this means filtering is more 'low-pass filter'.
fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
kernel_width *= 1.0 / scale if antialiasing else 1.0
# These are the coordinates of the output image
out_coordinates = np.arange(1, out_length + 1)
# since both scale-factor and output size can be provided simulatneously, perserving the center of the image requires shifting
# the output coordinates. the deviation is because out_length doesn't necesary equal in_length*scale.
# to keep the center we need to subtract half of this deivation so that we get equal margins for boths sides and center is preserved.
shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
# These are the matching positions of the output-coordinates on the input image coordinates.
# Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
# [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
# The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to
# the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big
# one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).
# So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is
# at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means:
# (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)
# This is the left boundary to start multiplying the filter from, it depends on the size of the filter
left_boundary = np.floor(match_coordinates - kernel_width / 2)
# Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers
# of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them)
expanded_kernel_width = np.ceil(kernel_width) + 2
# Determine a set of field_of_view for each each output position, these are the pixels in the input image
# that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the
# vertical dim is the pixels it 'sees' (kernel_size + 2)
field_of_view = np.squeeze(
np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)
)
# Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the
# vertical dim is a list of weights matching to the pixel in the field of view (that are specified in
# 'field_of_view')
weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
# Normalize weights to sum up to 1. be careful from dividing by 0
sum_weights = np.sum(weights, axis=1)
sum_weights[sum_weights == 0] = 1.0
weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
# We use this mirror structure as a trick for reflection padding at the boundaries
mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
# Get rid of weights and pixel positions that are of zero weight
non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
weights = np.squeeze(weights[:, non_zero_out_pixels])
field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
# Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size
return weights, field_of_view
self.down_sample = Resizer(in_shape, 1 / scale_factor)
for param in self.parameters():
param.requires_grad = False
def forward(self, data, **kwargs):
return self.down_sample(data)
# Gaussian blurring operator
class GaussialBlurOperator(nn.Module):
def __init__(self, kernel_size, intensity):
super().__init__()
class Blurkernel(nn.Module):
def __init__(self, blur_type="gaussian", kernel_size=31, std=3.0):
super().__init__()
self.blur_type = blur_type
self.kernel_size = kernel_size
self.std = std
self.seq = nn.Sequential(
nn.ReflectionPad2d(self.kernel_size // 2),
nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
)
self.weights_init()
def forward(self, x):
return self.seq(x)
def weights_init(self):
if self.blur_type == "gaussian":
n = np.zeros((self.kernel_size, self.kernel_size))
n[self.kernel_size // 2, self.kernel_size // 2] = 1
k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
k = torch.from_numpy(k)
self.k = k
for name, f in self.named_parameters():
f.data.copy_(k)
def update_weights(self, k):
if not torch.is_tensor(k):
k = torch.from_numpy(k)
for name, f in self.named_parameters():
f.data.copy_(k)
def get_kernel(self):
return self.k
self.kernel_size = kernel_size
self.conv = Blurkernel(blur_type="gaussian", kernel_size=kernel_size, std=intensity)
self.kernel = self.conv.get_kernel()
self.conv.update_weights(self.kernel.type(torch.float32))
for param in self.parameters():
param.requires_grad = False
def forward(self, data, **kwargs):
return self.conv(data)
def transpose(self, data, **kwargs):
return data
def get_kernel(self):
return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
# assuming the forward process y = f(x) is polluted by Gaussian noise, use l2 norm
def RMSELoss(yhat, y):
return torch.sqrt(torch.sum((yhat - y) ** 2))
# set up source image
src = Image.open("sample.png")
# read image into [1,3,H,W]
src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None]
# normalize image to [-1,1]
src = (src / 127.5) - 1.0
src = src.to("cuda")
# set up operator and measurement
# operator = SuperResolutionOperator(in_shape=src.shape, scale_factor=4).to("cuda")
operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
measurement = operator(src)
# set up scheduler
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(1000)
# set up model
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
save_image((src + 1.0) / 2.0, "dps_src.png")
save_image((measurement + 1.0) / 2.0, "dps_mea.png")
# finally, the pipeline
dpspipe = DPSPipeline(model, scheduler)
image = dpspipe(
measurement=measurement,
operator=operator,
loss_fn=RMSELoss,
zeta=1.0,
).images[0]
image.save("dps_generated_image.png")
+620 -22
View File
@@ -16,6 +16,7 @@
import ast
import gc
import inspect
import math
import warnings
from collections.abc import Iterable
@@ -23,16 +24,29 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention import Attention, GatedSelfAttentionDense
from diffusers.models.attention_processor import AttnProcessor2_0
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging, replace_example_docstring
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
EXAMPLE_DOC_STRING = """
@@ -44,6 +58,7 @@ EXAMPLE_DOC_STRING = """
>>> pipe = DiffusionPipeline.from_pretrained(
... "longlian/lmd_plus",
... custom_pipeline="llm_grounded_diffusion",
... custom_revision="main",
... variant="fp16", torch_dtype=torch.float16
... )
>>> pipe.enable_model_cpu_offload()
@@ -96,7 +111,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
# Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
DEFAULT_GUIDANCE_ATTN_KEYS = [
("mid", 0, 0, 0),
("up", 1, 0, 0),
("up", 1, 1, 0),
("up", 1, 2, 0),
]
def convert_attn_keys(key):
@@ -126,7 +146,15 @@ def scale_proportion(obj_box, H, W):
# Adapted from the parent class `AttnProcessor2_0`
class AttnProcessorWithHook(AttnProcessor2_0):
def __init__(self, attn_processor_key, hidden_size, cross_attention_dim, hook=None, fast_attn=True, enabled=True):
def __init__(
self,
attn_processor_key,
hidden_size,
cross_attention_dim,
hook=None,
fast_attn=True,
enabled=True,
):
super().__init__()
self.attn_processor_key = attn_processor_key
self.hidden_size = hidden_size
@@ -165,15 +193,16 @@ class AttnProcessorWithHook(AttnProcessor2_0):
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, scale=scale)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
@@ -186,7 +215,13 @@ class AttnProcessorWithHook(AttnProcessor2_0):
if self.hook is not None and self.enabled:
# Call the hook with query, key, value, and attention maps
self.hook(self.attn_processor_key, query_batch_dim, key_batch_dim, value_batch_dim, attention_probs)
self.hook(
self.attn_processor_key,
query_batch_dim,
key_batch_dim,
value_batch_dim,
attention_probs,
)
if self.fast_attn:
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -202,7 +237,12 @@ class AttnProcessorWithHook(AttnProcessor2_0):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
@@ -211,7 +251,7 @@ class AttnProcessorWithHook(AttnProcessor2_0):
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
@@ -226,7 +266,9 @@ class AttnProcessorWithHook(AttnProcessor2_0):
return hidden_states
class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
class LLMGroundedDiffusionPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
r"""
Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
@@ -257,6 +299,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
Whether a safety checker is needed for this pipeline.
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
objects_text = "Objects: "
bg_prompt_text = "Background prompt: "
bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
@@ -272,12 +319,91 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
# This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Initialize the attention hooks for LLM-grounded Diffusion
self.register_attn_hooks(unet)
self._saved_attn = None
@@ -464,7 +590,14 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
return token_map
def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_found=False, verbose=False):
def get_phrase_indices(
self,
prompt,
phrases,
token_map=None,
add_suffix_if_not_found=False,
verbose=False,
):
for obj in phrases:
# Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
if obj not in prompt:
@@ -485,7 +618,14 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
phrase_token_map_str = " ".join(phrase_token_map)
if verbose:
logger.info("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
logger.info(
"Full str:",
token_map_str,
"Substr:",
phrase_token_map_str,
"Phrase:",
phrases,
)
# Count the number of token before substr
# The substring comes with a trailing space that needs to be removed by minus one in the index.
@@ -552,7 +692,15 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
return loss
def compute_ca_loss(self, saved_attn, bboxes, phrase_indices, guidance_attn_keys, verbose=False, **kwargs):
def compute_ca_loss(
self,
saved_attn,
bboxes,
phrase_indices,
guidance_attn_keys,
verbose=False,
**kwargs,
):
"""
The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
`AttnProcessor` will put attention maps into the `save_attn_to_dict`.
@@ -605,6 +753,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -662,6 +811,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -724,9 +874,10 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
phrase_indices = []
prompt_parsed = []
for prompt_item in prompt:
phrase_indices_parsed_item, prompt_parsed_item = self.get_phrase_indices(
prompt_item, add_suffix_if_not_found=True
)
(
phrase_indices_parsed_item,
prompt_parsed_item,
) = self.get_phrase_indices(prompt_item, add_suffix_if_not_found=True)
phrase_indices.append(phrase_indices_parsed_item)
prompt_parsed.append(prompt_parsed_item)
prompt = prompt_parsed
@@ -759,6 +910,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -801,7 +957,10 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
if n_objs:
cond_boxes[:n_objs] = torch.tensor(boxes)
text_embeddings = torch.zeros(
max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
max_objs,
self.unet.config.cross_attention_dim,
device=device,
dtype=self.text_encoder.dtype,
)
if n_objs:
text_embeddings[:n_objs] = _text_embeddings
@@ -833,6 +992,9 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
loss_attn = torch.tensor(10000.0)
# 7. Denoising loop
@@ -869,6 +1031,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
@@ -1013,3 +1176,438 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
self.enable_attn_hook(enabled=False)
return latents, loss
# Below are methods copied from StableDiffusionPipeline
# The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
**kwargs,
)
# concatenate for backwards comp
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
output_hidden_states=True,
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stages where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
if not hasattr(self, "unet"):
raise ValueError("The pipeline must have `unet` for using FreeU.")
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
timesteps (`torch.Tensor`):
generate embedding vectors at these timesteps
embedding_dim (`int`, *optional*, defaults to 512):
dimension of the embeddings to generate
dtype:
data type of the generated embeddings
Returns:
`torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
"""
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
@property
def guidance_scale(self):
return self._guidance_scale
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
@property
def guidance_rescale(self):
return self._guidance_rescale
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
@property
def num_timesteps(self):
return self._num_timesteps
+14 -10
View File
@@ -250,6 +250,7 @@ def get_weighted_text_embeddings_sdxl(
neg_prompt: str = "",
neg_prompt_2: str = None,
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
"""
This function can process long prompt with weights, no length limitation
@@ -262,10 +263,13 @@ def get_weighted_text_embeddings_sdxl(
neg_prompt (str)
neg_prompt_2 (str)
num_images_per_prompt (int)
device (torch.device)
Returns:
prompt_embeds (torch.Tensor)
neg_prompt_embeds (torch.Tensor)
"""
device = device or pipe._execution_device
if prompt_2:
prompt = f"{prompt} {prompt_2}"
@@ -330,17 +334,17 @@ def get_weighted_text_embeddings_sdxl(
# get prompt embeddings one by one is not working.
for i in range(len(prompt_token_groups)):
# get positive prompt embeddings with weights
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
# use first text encoder
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
# use second text encoder
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
pooled_prompt_embeds = prompt_embeds_2[0]
@@ -357,16 +361,16 @@ def get_weighted_text_embeddings_sdxl(
embeds.append(token_embedding)
# get negative prompt embeddings with weights
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
# use first text encoder
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
# use second text encoder
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,772 @@
# Copyright 2023 The Intel Labs Team Authors and the HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import DiffusionPipeline
from diffusers.image_processor import PipelineDepthInput, PipelineImageInput, VaeImageProcessorLDM3D
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import LDM3DPipelineOutput
from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> from diffusers import StableDiffusionUpscaleLDM3DPipeline
>>> from PIL import Image
>>> from io import BytesIO
>>> import requests
>>> pipe = StableDiffusionUpscaleLDM3DPipeline.from_pretrained("Intel/ldm3d-sr")
>>> pipe = pipe.to("cuda")
>>> rgb_path = "https://huggingface.co/Intel/ldm3d-sr/resolve/main/lemons_ldm3d_rgb.jpg"
>>> depth_path = "https://huggingface.co/Intel/ldm3d-sr/resolve/main/lemons_ldm3d_depth.png"
>>> low_res_rgb = Image.open(BytesIO(requests.get(rgb_path).content)).convert("RGB")
>>> low_res_depth = Image.open(BytesIO(requests.get(depth_path).content)).convert("L")
>>> output = pipe(
... prompt="high quality high resolution uhd 4k image",
... rgb=low_res_rgb,
... depth=low_res_depth,
... num_inference_steps=50,
... target_res=[1024, 1024],
... )
>>> rgb_image, depth_image = output.rgb, output.depth
>>> rgb_image[0].save("hr_ldm3d_rgb.jpg")
>>> depth_image[0].save("hr_ldm3d_depth.png")
```
"""
class StableDiffusionUpscaleLDM3DPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image and 3D generation using LDM3D.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
A `UNet2DConditionModel` to denoise the encoded image latents.
low_res_scheduler ([`SchedulerMixin`]):
A scheduler used to add initial noise to the low resolution conditioning image. It must be an instance of
[`DDPMScheduler`].
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
_optional_components = ["safety_checker", "feature_extractor"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
low_res_scheduler: DDPMScheduler,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
watermarker: Optional[Any] = None,
max_noise_level: int = 350,
):
super().__init__()
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
low_res_scheduler=low_res_scheduler,
scheduler=scheduler,
safety_checker=safety_checker,
watermarker=watermarker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessorLDM3D(vae_scale_factor=self.vae_scale_factor, resample="bilinear")
# self.register_to_config(requires_safety_checker=requires_safety_checker)
self.register_to_config(max_noise_level=max_noise_level)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
**kwargs,
)
# concatenate for backwards comp
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d.StableDiffusionLDM3DPipeline.encode_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: procecss multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
rgb_feature_extractor_input = feature_extractor_input[0]
safety_checker_input = self.feature_extractor(rgb_feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
image,
noise_level,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
target_res=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if (
not isinstance(image, torch.Tensor)
and not isinstance(image, PIL.Image.Image)
and not isinstance(image, np.ndarray)
and not isinstance(image, list)
):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}"
)
# verify batch size of prompt and image are same if image is a list or tensor or numpy array
if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if isinstance(image, list):
image_batch_size = len(image)
else:
image_batch_size = image.shape[0]
if batch_size != image_batch_size:
raise ValueError(
f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
" Please make sure that passed `prompt` matches the batch size of `image`."
)
# check noise level
if noise_level > self.config.max_noise_level:
raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height, width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# def upcast_vae(self):
# dtype = self.vae.dtype
# self.vae.to(dtype=torch.float32)
# use_torch_2_0_or_xformers = isinstance(
# self.vae.decoder.mid_block.attentions[0].processor,
# (
# AttnProcessor2_0,
# XFormersAttnProcessor,
# LoRAXFormersAttnProcessor,
# LoRAAttnProcessor2_0,
# ),
# )
# # if xformers or torch_2_0 is used attention block does not need
# # to be in float32 which can save lots of memory
# if use_torch_2_0_or_xformers:
# self.vae.post_quant_conv.to(dtype)
# self.vae.decoder.conv_in.to(dtype)
# self.vae.decoder.mid_block.to(dtype)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
rgb: PipelineImageInput = None,
depth: PipelineDepthInput = None,
num_inference_steps: int = 75,
guidance_scale: float = 9.0,
noise_level: int = 20,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
target_res: Optional[List[int]] = [1024, 1024],
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image` or tensor representing an image batch to be upscaled.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 5.0):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
rgb,
noise_level,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Preprocess image
rgb, depth = self.image_processor.preprocess(rgb, depth, target_res=target_res)
rgb = rgb.to(dtype=prompt_embeds.dtype, device=device)
depth = depth.to(dtype=prompt_embeds.dtype, device=device)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Encode low resolutiom image to latent space
image = torch.cat([rgb, depth], axis=1)
latent_space_image = self.vae.encode(image).latent_dist.sample(generator)
latent_space_image *= self.vae.scaling_factor
noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
# noise_rgb = randn_tensor(rgb.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
# rgb = self.low_res_scheduler.add_noise(rgb, noise_rgb, noise_level)
# noise_depth = randn_tensor(depth.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
# depth = self.low_res_scheduler.add_noise(depth, noise_depth, noise_level)
batch_multiplier = 2 if do_classifier_free_guidance else 1
latent_space_image = torch.cat([latent_space_image] * batch_multiplier * num_images_per_prompt)
noise_level = torch.cat([noise_level] * latent_space_image.shape[0])
# 7. Prepare latent variables
height, width = latent_space_image.shape[2:]
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 8. Check that sizes of image and latents match
num_channels_image = latent_space_image.shape[1]
if num_channels_latents + num_channels_image != self.unet.config.in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_image`: {num_channels_image} "
f" = {num_channels_latents+num_channels_image}. Please verify the config of"
" `pipeline.unet` or your `image` input."
)
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 10. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = torch.cat([latent_model_input, latent_space_image], dim=1)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=noise_level,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
# make sure the VAE is in float32 mode, as it overflows in float16
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.scaling_factor, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
self.vae.to(dtype=torch.float16)
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
rgb, depth = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# 11. Apply watermark
if output_type == "pil" and self.watermarker is not None:
rgb = self.watermarker.apply_watermark(rgb)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return ((rgb, depth), has_nsfw_concept)
return LDM3DPipelineOutput(rgb=rgb, depth=depth, nsfw_content_detected=has_nsfw_concept)
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,589 @@
import math
from typing import Dict, Optional
import torch
import torchvision.transforms.functional as FF
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import USE_PEFT_BACKEND
try:
from compel import Compel
except ImportError:
Compel = None
KCOMM = "ADDCOMM"
KBRK = "BREAK"
class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
r"""
Args for Regional Prompting Pipeline:
rp_args:dict
Required
rp_args["mode"]: cols, rows, prompt, prompt-ex
for cols, rows mode
rp_args["div"]: ex) 1;1;1(Divide into 3 regions)
for prompt, prompt-ex mode
rp_args["th"]: ex) 0.5,0.5,0.6 (threshold for prompt mode)
Optional
rp_args["save_mask"]: True/False (save masks in prompt mode)
Pipeline for text-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
@torch.no_grad()
def __call__(
self,
prompt: str,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: str = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
rp_args: Dict[str, str] = None,
):
active = KBRK in prompt[0] if type(prompt) == list else KBRK in prompt # noqa: E721
if negative_prompt is None:
negative_prompt = "" if type(prompt) == str else [""] * len(prompt) # noqa: E721
device = self._execution_device
regions = 0
self.power = int(rp_args["power"]) if "power" in rp_args else 1
prompts = prompt if type(prompt) == list else [prompt] # noqa: E721
n_prompts = negative_prompt if type(negative_prompt) == list else [negative_prompt] # noqa: E721
self.batch = batch = num_images_per_prompt * len(prompts)
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
cn = len(all_prompts_cn) == len(all_n_prompts_cn)
if Compel:
compel = Compel(tokenizer=self.tokenizer, text_encoder=self.text_encoder)
def getcompelembs(prps):
embl = []
for prp in prps:
embl.append(compel.build_conditioning_tensor(prp))
return torch.cat(embl)
conds = getcompelembs(all_prompts_cn)
unconds = getcompelembs(all_n_prompts_cn) if cn else getcompelembs(n_prompts)
embs = getcompelembs(prompts)
n_embs = getcompelembs(n_prompts)
prompt = negative_prompt = None
else:
conds = self.encode_prompt(prompts, device, 1, True)[0]
unconds = (
self.encode_prompt(n_prompts, device, 1, True)[0]
if cn
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
)
embs = n_embs = None
if not active:
pcallback = None
mode = None
else:
if any(x in rp_args["mode"].upper() for x in ["COL", "ROW"]):
mode = "COL" if "COL" in rp_args["mode"].upper() else "ROW"
ocells, icells, regions = make_cells(rp_args["div"])
elif "PRO" in rp_args["mode"].upper():
regions = len(all_prompts_p[0])
mode = "PROMPT"
reset_attnmaps(self)
self.ex = "EX" in rp_args["mode"].upper()
self.target_tokens = target_tokens = tokendealer(self, all_prompts_p)
thresholds = [float(x) for x in rp_args["th"].split(",")]
orig_hw = (height, width)
revers = True
def pcallback(s_self, step: int, timestep: int, latents: torch.FloatTensor, selfs=None):
if "PRO" in mode: # in Prompt mode, make masks from sum of attension maps
self.step = step
if len(self.attnmaps_sizes) > 3:
self.history[step] = self.attnmaps.copy()
for hw in self.attnmaps_sizes:
allmasks = []
basemasks = [None] * batch
for tt, th in zip(target_tokens, thresholds):
for b in range(batch):
key = f"{tt}-{b}"
_, mask, _ = makepmask(self, self.attnmaps[key], hw[0], hw[1], th, step)
mask = mask.unsqueeze(0).unsqueeze(-1)
if self.ex:
allmasks[b::batch] = [x - mask for x in allmasks[b::batch]]
allmasks[b::batch] = [torch.where(x > 0, 1, 0) for x in allmasks[b::batch]]
allmasks.append(mask)
basemasks[b] = mask if basemasks[b] is None else basemasks[b] + mask
basemasks = [1 - mask for mask in basemasks]
basemasks = [torch.where(x > 0, 1, 0) for x in basemasks]
allmasks = basemasks + allmasks
self.attnmasks[hw] = torch.cat(allmasks)
self.maskready = True
return latents
def hook_forward(module):
# diffusers==0.23.2
def forward(
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
attn = module
xshape = hidden_states.shape
self.hw = (h, w) = split_dims(xshape[1], *orig_hw)
if revers:
nx, px = hidden_states.chunk(2)
else:
px, nx = hidden_states.chunk(2)
if cn:
hidden_states = torch.cat([px for i in range(regions)] + [nx for i in range(regions)], 0)
encoder_hidden_states = torch.cat([conds] + [unconds])
else:
hidden_states = torch.cat([px for i in range(regions)] + [nx], 0)
encoder_hidden_states = torch.cat([conds] + [unconds])
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = scaled_dot_product_attention(
self,
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
getattn="PRO" in mode,
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
#### Regional Prompting Col/Row mode
if any(x in mode for x in ["COL", "ROW"]):
reshaped = hidden_states.reshape(hidden_states.size()[0], h, w, hidden_states.size()[2])
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
outs = [px, nx] if cn else [px]
for out in outs:
c = 0
for i, ocell in enumerate(ocells):
for icell in icells[i]:
if "ROW" in mode:
out[
0:batch,
int(h * ocell[0]) : int(h * ocell[1]),
int(w * icell[0]) : int(w * icell[1]),
:,
] = out[
c * batch : (c + 1) * batch,
int(h * ocell[0]) : int(h * ocell[1]),
int(w * icell[0]) : int(w * icell[1]),
:,
]
else:
out[
0:batch,
int(h * icell[0]) : int(h * icell[1]),
int(w * ocell[0]) : int(w * ocell[1]),
:,
] = out[
c * batch : (c + 1) * batch,
int(h * icell[0]) : int(h * icell[1]),
int(w * ocell[0]) : int(w * ocell[1]),
:,
]
c += 1
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
hidden_states = hidden_states.reshape(xshape)
#### Regional Prompting Prompt mode
elif "PRO" in mode:
center = reshaped.shape[0] // 2
px = reshaped[0:center] if cn else reshaped[0:-batch]
nx = reshaped[center:] if cn else reshaped[-batch:]
if (h, w) in self.attnmasks and self.maskready:
def mask(input):
out = torch.multiply(input, self.attnmasks[(h, w)])
for b in range(batch):
for r in range(1, regions):
out[b] = out[b] + out[r * batch + b]
return out
px, nx = (mask(px), mask(nx)) if cn else (mask(px), nx)
px, nx = (px[0:batch], nx[0:batch]) if cn else (px[0:batch], nx)
hidden_states = torch.cat([nx, px], 0) if revers else torch.cat([px, nx], 0)
return hidden_states
return forward
def hook_forwards(root_module: torch.nn.Module):
for name, module in root_module.named_modules():
if "attn2" in name and module.__class__.__name__ == "Attention":
module.forward = hook_forward(module)
hook_forwards(self.unet)
output = StableDiffusionPipeline(**self.components)(
prompt=prompt,
prompt_embeds=embs,
negative_prompt=negative_prompt,
negative_prompt_embeds=n_embs,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
output_type=output_type,
return_dict=return_dict,
callback_on_step_end=pcallback,
)
if "save_mask" in rp_args:
save_mask = rp_args["save_mask"]
else:
save_mask = False
if mode == "PROMPT" and save_mask:
saveattnmaps(self, output, height, width, thresholds, num_inference_steps // 2, regions)
return output
### Make prompt list for each regions
def promptsmaker(prompts, batch):
out_p = []
plen = len(prompts)
for prompt in prompts:
add = ""
if KCOMM in prompt:
add, prompt = prompt.split(KCOMM)
add = add + " "
prompts = prompt.split(KBRK)
out_p.append([add + p for p in prompts])
out = [None] * batch * len(out_p[0]) * len(out_p)
for p, prs in enumerate(out_p): # inputs prompts
for r, pr in enumerate(prs): # prompts for regions
start = (p + r * plen) * batch
out[start : start + batch] = [pr] * batch # P1R1B1,P1R1B2...,P1R2B1,P1R2B2...,P2R1B1...
return out, out_p
### make regions from ratios
### ";" makes outercells, "," makes inner cells
def make_cells(ratios):
if ";" not in ratios and "," in ratios:
ratios = ratios.replace(",", ";")
ratios = ratios.split(";")
ratios = [inratios.split(",") for inratios in ratios]
icells = []
ocells = []
def startend(cells, array):
current_start = 0
array = [float(x) for x in array]
for value in array:
end = current_start + (value / sum(array))
cells.append([current_start, end])
current_start = end
startend(ocells, [r[0] for r in ratios])
for inratios in ratios:
if 2 > len(inratios):
icells.append([[0, 1]])
else:
add = []
startend(add, inratios[1:])
icells.append(add)
return ocells, icells, sum(len(cell) for cell in icells)
def make_emblist(self, prompts):
with torch.no_grad():
tokens = self.tokenizer(
prompts, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
).input_ids.to(self.device)
embs = self.text_encoder(tokens, output_hidden_states=True).last_hidden_state.to(self.device, dtype=self.dtype)
return embs
def split_dims(xs, height, width):
xs = xs
def repeat_div(x, y):
while y > 0:
x = math.ceil(x / 2)
y = y - 1
return x
scale = math.ceil(math.log2(math.sqrt(height * width / xs)))
dsh = repeat_div(height, scale)
dsw = repeat_div(width, scale)
return dsh, dsw
##### for prompt mode
def get_attn_maps(self, attn):
height, width = self.hw
target_tokens = self.target_tokens
if (height, width) not in self.attnmaps_sizes:
self.attnmaps_sizes.append((height, width))
for b in range(self.batch):
for t in target_tokens:
power = self.power
add = attn[b, :, :, t[0] : t[0] + len(t)] ** (power) * (self.attnmaps_sizes.index((height, width)) + 1)
add = torch.sum(add, dim=2)
key = f"{t}-{b}"
if key not in self.attnmaps:
self.attnmaps[key] = add
else:
if self.attnmaps[key].shape[1] != add.shape[1]:
add = add.view(8, height, width)
add = FF.resize(add, self.attnmaps_sizes[0], antialias=None)
add = add.reshape_as(self.attnmaps[key])
self.attnmaps[key] = self.attnmaps[key] + add
def reset_attnmaps(self): # init parameters in every batch
self.step = 0
self.attnmaps = {} # maked from attention maps
self.attnmaps_sizes = [] # height,width set of u-net blocks
self.attnmasks = {} # maked from attnmaps for regions
self.maskready = False
self.history = {}
def saveattnmaps(self, output, h, w, th, step, regions):
masks = []
for i, mask in enumerate(self.history[step].values()):
img, _, mask = makepmask(self, mask, h, w, th[i % len(th)], step)
if self.ex:
masks = [x - mask for x in masks]
masks.append(mask)
if len(masks) == regions - 1:
output.images.extend([FF.to_pil_image(mask) for mask in masks])
masks = []
else:
output.images.append(img)
def makepmask(
self, mask, h, w, th, step
): # make masks from attention cache return [for preview, for attention, for Latent]
th = th - step * 0.005
if 0.05 >= th:
th = 0.05
mask = torch.mean(mask, dim=0)
mask = mask / mask.max().item()
mask = torch.where(mask > th, 1, 0)
mask = mask.float()
mask = mask.view(1, *self.attnmaps_sizes[0])
img = FF.to_pil_image(mask)
img = img.resize((w, h))
mask = FF.resize(mask, (h, w), interpolation=FF.InterpolationMode.NEAREST, antialias=None)
lmask = mask
mask = mask.reshape(h * w)
mask = torch.where(mask > 0.1, 1, 0)
return img, mask, lmask
def tokendealer(self, all_prompts):
for prompts in all_prompts:
targets = [p.split(",")[-1] for p in prompts[1:]]
tt = []
for target in targets:
ptokens = (
self.tokenizer(
prompts,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids
)[0]
ttokens = (
self.tokenizer(
target,
max_length=self.tokenizer.model_max_length,
padding=True,
truncation=True,
return_tensors="pt",
).input_ids
)[0]
tlist = []
for t in range(ttokens.shape[0] - 2):
for p in range(ptokens.shape[0]):
if ttokens[t + 1] == ptokens[p]:
tlist.append(p)
if tlist != []:
tt.append(tlist)
return tt
def scaled_dot_product_attention(
self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, getattn=False
) -> torch.Tensor:
# Efficient implementation equivalent to the following:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=self.device)
if is_causal:
assert attn_mask is None
temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
else:
attn_bias += attn_mask
attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = torch.softmax(attn_weight, dim=-1)
if getattn:
get_attn_maps(self, attn_weight)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ value
+2 -2
View File
@@ -21,7 +21,7 @@ from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.configuration_utils import FrozenDict
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -62,7 +62,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion on IPEX.
@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -709,6 +709,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
@@ -724,7 +725,15 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
@@ -41,7 +41,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -710,6 +710,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
image_height: int = 512,
@@ -725,7 +726,15 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
@@ -40,7 +40,7 @@ from polygraphy.backend.trt import (
save_engine,
)
from polygraphy.backend.trt import util as trt_util
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import (
@@ -624,6 +624,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae"],
image_height: int = 768,
@@ -639,7 +640,15 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
timing_cache: str = "timing_cache",
):
super().__init__(
vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
vae,
text_encoder,
tokenizer,
unet,
scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.vae.forward = self.vae.decode
@@ -71,7 +71,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.18.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.18.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.18.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -657,6 +657,15 @@ def parse_args():
default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.",
)
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--ema_decay",
@@ -1138,7 +1147,7 @@ def main(args):
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
@@ -71,7 +71,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.18.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -677,6 +677,15 @@ def parse_args():
default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.",
)
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--ema_decay",
@@ -1233,6 +1242,7 @@ def main(args):
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
@@ -1243,7 +1253,7 @@ def main(args):
noise_pred = unet(
noisy_model_input,
start_timesteps,
timestep_cond=None,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
@@ -1308,7 +1318,7 @@ def main(args):
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
timestep_cond=None,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
+120
View File
@@ -0,0 +1,120 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class ControlNet(ExamplesTestsAccelerate):
def test_controlnet_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/controlnet/train_controlnet.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/controlnet/train_controlnet.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/controlnet/train_controlnet.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
--max_train_steps=11
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
)
class ControlNetSDXL(ExamplesTestsAccelerate):
def test_controlnet_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/controlnet/train_controlnet_sdxl.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+14 -8
View File
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -86,6 +86,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
controlnet=controlnet,
safety_checker=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -249,10 +250,13 @@ def parse_args(input_args=None):
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
@@ -767,11 +771,13 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
if args.controlnet_model_name_or_path:
+1 -1
View File
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = logging.getLogger(__name__)
+21 -10
View File
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -74,6 +74,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
unet=unet,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -243,15 +244,18 @@ def parse_args(input_args=None):
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
" If not specified controlnet weights are initialized from unet.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
@@ -793,10 +797,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -810,10 +820,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -824,9 +834,10 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
if args.controlnet_model_name_or_path:
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class CustomDiffusion(ExamplesTestsAccelerate):
def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/custom_diffusion/train_custom_diffusion.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt <new1>
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 1.0e-05
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--modifier_token <new1>
--no_safe_serialization
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/custom_diffusion/train_custom_diffusion.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=<new1>
--resolution=64
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--no_safe_serialization
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/custom_diffusion/train_custom_diffusion.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=<new1>
--resolution=64
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=9
--checkpointing_steps=2
--no_safe_serialization
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/custom_diffusion/train_custom_diffusion.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=<new1>
--resolution=64
--train_batch_size=1
--modifier_token=<new1>
--dataloader_num_workers=0
--max_train_steps=11
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--no_safe_serialization
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
@@ -62,7 +62,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -332,6 +332,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -740,6 +746,7 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -801,11 +808,13 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# Adding a modifier token which is optimized ####
@@ -1229,6 +1238,7 @@ def main(args):
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -1278,7 +1288,7 @@ def main(args):
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
+230
View File
@@ -0,0 +1,230 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import sys
import tempfile
from diffusers import DiffusionPipeline, UNet2DConditionModel
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBooth(ExamplesTestsAccelerate):
def test_dreambooth(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_dreambooth_if(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--pre_compute_text_embeddings
--tokenizer_max_length=77
--text_encoder_use_attention_mask
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_dreambooth_checkpointing(self):
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--instance_data_dir docs/source/en/imgs
--instance_prompt {instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
# check can run the original fully trained output pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--instance_data_dir docs/source/en/imgs
--instance_prompt {instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(instance_prompt, num_inference_steps=2)
# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
# check new checkpoints exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/dreambooth/train_dreambooth.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
+388
View File
@@ -0,0 +1,388 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
from diffusers import DiffusionPipeline # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRA(ExamplesTestsAccelerate):
def test_dreambooth_lora(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"unet"` in their names.
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
self.assertTrue(starts_with_unet)
def test_dreambooth_lora_with_text_encoder(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--train_text_encoder
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# check `text_encoder` is present at all.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
keys = lora_state_dict.keys()
is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
self.assertTrue(is_text_encoder_present)
# the names of the keys of the state dict should either start with `unet`
# or `text_encoder`.
is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
self.assertTrue(is_correct_naming)
def test_dreambooth_lora_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--instance_data_dir=docs/source/en/imgs
--output_dir={tmpdir}
--instance_prompt=prompt
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=11
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
def test_dreambooth_lora_if_model(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora.py
--pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--pre_compute_text_embeddings
--tokenizer_max_length=77
--text_encoder_use_attention_mask
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"unet"` in their names.
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
self.assertTrue(starts_with_unet)
class DreamBoothLoRASDXL(ExamplesTestsAccelerate):
def test_dreambooth_lora_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"unet"` in their names.
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
self.assertTrue(starts_with_unet)
def test_dreambooth_lora_sdxl_with_text_encoder(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--train_text_encoder
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
keys = lora_state_dict.keys()
starts_with_unet = all(
k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
)
self.assertTrue(starts_with_unet)
def test_dreambooth_lora_sdxl_custom_captions(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--caption_column text
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--caption_column text
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--train_text_encoder
""".split()
run_command(self._launch_args + test_args)
def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path {pipeline_path}
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--checkpointing_steps=2
--checkpoints_total_limit=2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe("a prompt", num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path {pipeline_path}
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--checkpointing_steps=2
--checkpoints_total_limit=2
--train_text_encoder
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe("a prompt", num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
+15 -9
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -139,6 +139,7 @@ def log_validation(
text_encoder=text_encoder,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
**pipeline_args,
)
@@ -239,10 +240,13 @@ def parse_args(input_args=None):
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
@@ -296,7 +300,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
default="text-inversion-model",
default="dreambooth-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -859,6 +863,7 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -912,18 +917,18 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
if model_has_vae(args):
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
else:
vae = None
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1379,6 +1384,7 @@ def main(args):
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
**pipeline_args,
)
+9 -3
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
@@ -460,7 +460,10 @@ def main():
# Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
args.pretrained_model_name_or_path,
subfolder="text_encoder",
dtype=weight_dtype,
revision=args.revision,
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
vae_arg,
@@ -468,7 +471,10 @@ def main():
**vae_kwargs,
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
args.pretrained_model_name_or_path,
subfolder="unet",
dtype=weight_dtype,
revision=args.revision,
)
# Optimization
+47 -6
View File
@@ -57,7 +57,7 @@ from diffusers.models.attention_processor import (
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
@@ -65,11 +65,44 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -150,6 +183,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -717,6 +756,7 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -770,11 +810,11 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
try:
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
except OSError:
# IF does not have a VAE so let's just set it to None
@@ -782,7 +822,7 @@ def main(args):
vae = None
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# We only train the additional adapter LoRA layers
@@ -1277,6 +1317,7 @@ def main(args):
unet=accelerator.unwrap_model(unet),
text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1362,7 +1403,7 @@ def main(args):
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+499 -111
View File
@@ -50,50 +50,113 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
instance_prompt=str,
validation_prompt=str,
repo_folder=None,
vae_path=None,
):
img_str = ""
img_str = "widget:\n" if images else ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
"""
yaml = f"""
---
license: openrail++
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
inference: true
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
model_card = f"""
# LoRA DreamBooth - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
model_card = f"""
# SDXL LoRA DreamBooth - {repo_id}
<Gallery />
## Model description
These are {repo_id} LoRA adaption weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
LoRA for the text encoder was enabled: {train_text_encoder}.
Special VAE used for training: {vae_path}.
## Trigger words
You should use {instance_prompt} to trigger the image generation.
## Download model
Weights for this model are available in Safetensors format.
[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
@@ -141,13 +204,59 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
help=("A folder containing the training data. "),
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--image_column",
type=str,
default="image",
help="The column of the dataset containing the target image. By "
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
parser.add_argument(
"--caption_column",
type=str,
default=None,
help="The column of the dataset containing the instance prompt for each image",
)
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
parser.add_argument(
"--class_data_dir",
type=str,
@@ -160,7 +269,7 @@ def parse_args(input_args=None):
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
"--class_prompt",
@@ -299,9 +408,16 @@ def parse_args(input_args=None):
parser.add_argument(
"--learning_rate",
type=float,
default=5e-4,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--text_encoder_lr",
type=float,
default=5e-6,
help="Text encoder learning rate to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
@@ -317,6 +433,14 @@ def parse_args(input_args=None):
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
@@ -335,13 +459,59 @@ def parse_args(input_args=None):
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
"--optimizer",
type=str,
default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
)
parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--prodigy_beta3",
type=float,
default=None,
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW",
)
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
)
parser.add_argument(
"--prodigy_use_bias_correction",
type=bool,
default=True,
help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
)
parser.add_argument(
"--prodigy_safeguard_warmup",
type=bool,
default=True,
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
"Ignored if optimizer is adamW",
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -414,6 +584,12 @@ def parse_args(input_args=None):
else:
args = parser.parse_args()
if args.dataset_name is None and args.instance_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
if args.dataset_name is not None and args.instance_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -442,20 +618,84 @@ class DreamBoothDataset(Dataset):
def __init__(
self,
instance_data_root,
instance_prompt,
class_prompt,
class_data_root=None,
class_num=None,
size=1024,
repeats=1,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
self.class_prompt = class_prompt
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
raise ImportError(
"You are trying to load your data using the datasets library. If you wish to train using custom "
"captions please install the datasets library: `pip install datasets`. If you wish to load a "
"local folder containing images only, specify --instance_data_dir instead."
)
# Downloading and loading a dataset from the hub.
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
if args.caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
"column as --caption_column"
)
self.custom_instance_prompts = None
else:
if args.caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][args.caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
else:
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
self.custom_instance_prompts = None
self.instance_images = []
for img in instance_images:
self.instance_images.extend(itertools.repeat(img, repeats))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
if class_data_root is not None:
@@ -484,13 +724,23 @@ class DreamBoothDataset(Dataset):
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
instance_image = self.instance_images[index % self.num_instance_images]
instance_image = exif_transpose(instance_image)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
example["instance_prompt"] = caption
else:
example["instance_prompt"] = self.instance_prompt
else: # costum prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)
@@ -498,22 +748,25 @@ class DreamBoothDataset(Dataset):
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt"] = self.class_prompt
return example
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
pixel_values += [example["class_images"] for example in examples]
prompts += [example["class_prompt"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values}
batch = {"pixel_values": pixel_values, "prompts": prompts}
return batch
@@ -630,6 +883,7 @@ def main(args):
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -668,10 +922,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -685,10 +945,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -696,10 +956,13 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# We only train the additional adapter LoRA layers
@@ -732,7 +995,8 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
@@ -866,35 +1130,119 @@ def main(args):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
# Optimization parameters
unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
if args.train_text_encoder:
# different learning rate for text encoder and unet
text_lora_parameters_one_with_lr = {
"params": text_lora_parameters_one,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_lora_parameters_two_with_lr = {
"params": text_lora_parameters_two,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
unet_lora_parameters_with_lr,
text_lora_parameters_one_with_lr,
text_lora_parameters_two_with_lr,
]
else:
optimizer_class = torch.optim.AdamW
params_to_optimize = [unet_lora_parameters_with_lr]
# Optimizer creation
params_to_optimize = (
itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
if args.train_text_encoder
else unet_lora_parameters
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warn(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warn(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
if args.optimizer.lower() == "adamw":
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.optimizer.lower() == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warn(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warn(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_num=args.num_class_images,
size=args.resolution,
repeats=args.repeats,
center_crop=args.center_crop,
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
# Computes additional embeddings/ids required by the SDXL UNet.
# regular text emebddings (when `train_text_encoder` is not True)
# regular text embeddings (when `train_text_encoder` is not True)
# pooled text embeddings
# time ids
@@ -921,7 +1269,11 @@ def main(args):
# Handle instance prompt.
instance_time_ids = compute_time_ids()
if not args.train_text_encoder:
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers
)
@@ -934,49 +1286,36 @@ def main(args):
args.class_prompt, text_encoders, tokenizers
)
# Clear the memory here.
if not args.train_text_encoder:
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
gc.collect()
torch.cuda.empty_cache()
# Pack the statically computed variables appropriately. This is so that we don't
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
add_time_ids = instance_time_ids
if args.with_prior_preservation:
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
if not args.train_text_encoder:
prompt_embeds = instance_prompt_hidden_states
unet_add_text_embeds = instance_pooled_prompt_embeds
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
if args.with_prior_preservation:
class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_num=args.num_class_images,
size=args.resolution,
center_crop=args.center_crop,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
if not train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
prompt_embeds = instance_prompt_hidden_states
unet_add_text_embeds = instance_pooled_prompt_embeds
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
# batch prompts on all training steps
else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
if args.with_prior_preservation:
class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -1079,6 +1418,17 @@ def main(args):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
prompts, text_encoders, tokenizers
)
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts)
tokens_two = tokenize_prompt(tokenizer_two, prompts)
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1099,16 +1449,21 @@ def main(args):
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# Calculate the elements to repeat depending on the use of prior-preservation.
elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
if not train_dataset.custom_instance_prompts:
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
else:
elems_to_repeat_text_embeds = 1
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
# Predict the noise residual
if not args.train_text_encoder:
unet_added_conditions = {
"time_ids": add_time_ids.repeat(elems_to_repeat, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
}
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input,
timesteps,
@@ -1116,15 +1471,17 @@ def main(args):
added_cond_kwargs=unet_added_conditions,
).sample
else:
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
prompt=None,
text_input_ids_list=[tokens_one, tokens_two],
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
unet_added_conditions.update(
{"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
).sample
@@ -1142,16 +1499,34 @@ def main(args):
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -1212,10 +1587,16 @@ def main(args):
# create pipeline
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
variant=args.variant,
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path,
subfolder="text_encoder_2",
revision=args.revision,
variant=args.variant,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1224,6 +1605,7 @@ def main(args):
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1301,10 +1683,15 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
@@ -1353,7 +1740,8 @@ def main(args):
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
@@ -0,0 +1,101 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class InstructPix2Pix(ExamplesTestsAccelerate):
def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/instruct_pix2pix/train_instruct_pix2pix.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=7
--checkpointing_steps=2
--checkpoints_total_limit=2
--output_dir {tmpdir}
--seed=0
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/instruct_pix2pix/train_instruct_pix2pix.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=9
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
""".split()
run_command(self._launch_args + test_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
resume_run_args = f"""
examples/instruct_pix2pix/train_instruct_pix2pix.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name=hf-internal-testing/instructpix2pix-10-samples
--resolution=64
--random_flip
--train_batch_size=1
--max_train_steps=11
--checkpointing_steps=2
--output_dir {tmpdir}
--seed=0
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -78,6 +78,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -435,9 +441,11 @@ def main():
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
@@ -915,6 +923,7 @@ def main():
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -966,6 +975,7 @@ def main():
vae=accelerator.unwrap_model(vae),
unet=unet,
revision=args.revision,
variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -118,6 +118,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -484,9 +490,10 @@ def main():
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
@@ -695,10 +702,16 @@ def main():
# Load scheduler, tokenizer and models.
tokenizer_1 = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_2 = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
text_encoder_cls_2 = import_model_class_from_model_name_or_path(
@@ -708,10 +721,10 @@ def main():
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_1 = text_encoder_cls_1.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_2 = text_encoder_cls_2.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
# We ALWAYS pre-compute the additional condition embeddings needed for SDXL
@@ -1109,6 +1122,7 @@ def main():
tokenizer_2=tokenizer_2,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -1176,6 +1190,7 @@ def main():
vae=vae,
unet=unet,
revision=args.revision,
variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+51
View File
@@ -0,0 +1,51 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class T2IAdapter(ExamplesTestsAccelerate):
def test_t2i_adapter_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/t2i_adapter/train_t2i_adapter_sdxl.py
--pretrained_model_name_or_path=hf-internal-testing/tiny-stable-diffusion-xl-pipe
--adapter_model_name_or_path=hf-internal-testing/tiny-adapter
--dataset_name=hf-internal-testing/fill10
--output_dir={tmpdir}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=9
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
+20 -6
View File
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -85,6 +85,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
unet=unet,
adapter=adapter,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -262,6 +263,12 @@ def parse_args(input_args=None):
" float32 precision."
),
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -812,10 +819,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -829,10 +842,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -843,9 +856,10 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
if args.adapter_model_name_or_path:
File diff suppressed because it is too large Load Diff
+61
View File
@@ -0,0 +1,61 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import subprocess
import tempfile
import unittest
from typing import List
from accelerate.utils import write_basic_config
# These utils relate to ensuring the right error message is received when running scripts
class SubprocessCallException(Exception):
pass
def run_command(command: List[str], return_stdout=False):
"""
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
if an error occurred while running `command`
"""
try:
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
if return_stdout:
if hasattr(output, "decode"):
output = output.decode("utf-8")
return output
except subprocess.CalledProcessError as e:
raise SubprocessCallException(
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
) from e
class ExamplesTestsAccelerate(unittest.TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._tmpdir = tempfile.mkdtemp()
cls.configPath = os.path.join(cls._tmpdir, "default_config.yml")
write_basic_config(save_location=cls.configPath)
cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
@classmethod
def tearDownClass(cls):
super().tearDownClass()
shutil.rmtree(cls._tmpdir)
@@ -0,0 +1,373 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import sys
import tempfile
from diffusers import DiffusionPipeline, UNet2DConditionModel # noqa: E402
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class TextToImage(ExamplesTestsAccelerate):
def test_text_to_image(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_text_to_image_checkpointing(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist
"checkpoint-4",
"checkpoint-6",
},
)
def test_text_to_image_checkpointing_use_ema(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 5, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 5
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--use_ema
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
# check can run an intermediate checkpoint
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--use_ema
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{
# no checkpoint-2 -> check old checkpoints do not exist
# check new checkpoints exist
"checkpoint-4",
"checkpoint-6",
},
)
def test_text_to_image_checkpointing_checkpoints_total_limit(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
def test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 9, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4, 6, 8
initial_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 9
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
# resume and we should try to checkpoint at 10, where we'll have to remove
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
resume_run_args = f"""
examples/text_to_image/train_text_to_image.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 11
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
class TextToImageSDXL(ExamplesTestsAccelerate):
def test_text_to_image_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/text_to_image/train_text_to_image_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
@@ -0,0 +1,308 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
from diffusers import DiffusionPipeline # noqa: E402
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class TextToImageLoRA(ExamplesTestsAccelerate):
def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
prompt = "a prompt"
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora_sdxl.py
--pretrained_model_name_or_path {pipeline_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
--seed=0
--num_validation_images=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 9, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4, 6, 8
initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 9
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
--num_validation_images=0
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
)
# resume and we should try to checkpoint at 10, where we'll have to remove
# checkpoint-2 and checkpoint-4 instead of just a single previous checkpoint
resume_run_args = f"""
examples/text_to_image/train_text_to_image_lora.py
--pretrained_model_name_or_path {pretrained_model_name_or_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--center_crop
--random_flip
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 11
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3
--seed=0
--num_validation_images=0
""".split()
run_command(self._launch_args + resume_run_args)
pipe = DiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None
)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)
class TextToImageLoRASDXL(ExamplesTestsAccelerate):
def test_text_to_image_lora_sdxl(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/text_to_image/train_text_to_image_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
def test_text_to_image_lora_sdxl_with_text_encoder(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/text_to_image/train_text_to_image_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--train_text_encoder
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
keys = lora_state_dict.keys()
starts_with_unet = all(
k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
)
self.assertTrue(starts_with_unet)
def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
prompt = "a prompt"
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted
initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora_sdxl.py
--pretrained_model_name_or_path {pipeline_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--train_text_encoder
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + initial_run_args)
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
+12 -4
View File
@@ -53,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -148,6 +148,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
unet=accelerator.unwrap_model(unet),
safety_checker=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -209,6 +210,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -567,10 +574,10 @@ def main():
# across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
@@ -585,7 +592,7 @@ def main():
# Create EMA for the unet.
if args.use_ema:
ema_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
@@ -1026,6 +1033,7 @@ def main():
vae=vae,
unet=unet,
revision=args.revision,
variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = logging.getLogger(__name__)
@@ -54,6 +54,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -40,8 +40,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -49,11 +48,44 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
@@ -98,6 +130,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -422,9 +460,11 @@ def main():
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
@@ -458,25 +498,43 @@ def main():
# => 32 layers
# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=args.rank,
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
unet.set_attn_processor(lora_attn_procs)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -491,8 +549,6 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
lora_layers = AttnProcsLayers(unet.attn_processors)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
@@ -517,7 +573,7 @@ def main():
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
lora_layers.parameters(),
unet_lora_parameters,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -644,8 +700,8 @@ def main():
)
# Prepare everything with our `accelerator`.
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -777,7 +833,7 @@ def main():
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers.parameters()
params_to_clip = unet_lora_parameters
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -833,6 +889,7 @@ def main():
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -889,7 +946,7 @@ def main():
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
)
pipeline = pipeline.to(accelerator.device)
@@ -50,7 +50,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -58,11 +58,44 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -147,6 +180,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -537,10 +576,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -554,10 +599,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -565,10 +610,13 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# We only train the additional adapter LoRA layers
@@ -1143,6 +1191,7 @@ def main(args):
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1208,7 +1257,11 @@ def main(args):
# Final inference
# Load previous pipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -148,6 +148,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -618,10 +624,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -636,10 +648,10 @@ def main(args):
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# Check for terminal SNR in combination with SNR Gamma
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -647,10 +659,13 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# Freeze vae and text encoders.
@@ -677,7 +692,7 @@ def main(args):
# Create EMA for the unet.
if args.use_ema:
ema_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
@@ -1145,12 +1160,14 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
if args.prediction_type is not None:
@@ -1198,10 +1215,16 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path,
unet=unet,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
if args.prediction_type is not None:
scheduler_args = {"prediction_type": args.prediction_type}
@@ -0,0 +1,160 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class TextualInversion(ExamplesTestsAccelerate):
def test_textual_inversion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--train_data_dir docs/source/en/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
def test_textual_inversion_checkpointing(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--train_data_dir docs/source/en/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 3
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=1
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + test_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-3"},
)
def test_textual_inversion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--train_data_dir docs/source/en/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 3
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=1
""".split()
run_command(self._launch_args + test_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-1", "checkpoint-2", "checkpoint-3"},
)
resume_run_args = f"""
examples/textual_inversion/textual_inversion.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
--train_data_dir docs/source/en/imgs
--learnable_property object
--placeholder_token <cat-toy>
--initializer_token a
--validation_prompt <cat-toy>
--validation_steps 1
--save_steps 1
--num_vectors 2
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=1
--resume_from_checkpoint=checkpoint-3
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-3", "checkpoint-4"},
)
@@ -79,7 +79,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__)
@@ -126,6 +126,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
vae=vae,
safety_checker=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -206,6 +207,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -624,9 +631,11 @@ def main():
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# Add the placeholder token in tokenizer
@@ -752,6 +761,7 @@ def main():
num_cycles=args.lr_num_cycles,
)
text_encoder.train()
# Prepare everything with our `accelerator`.
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = logging.getLogger(__name__)
@@ -0,0 +1,130 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class Unconditional(ExamplesTestsAccelerate):
def test_train_unconditional(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name hf-internal-testing/dummy_image_class_data
--model_config_name_or_path diffusers/ddpm_dummy
--resolution 64
--output_dir {tmpdir}
--train_batch_size 2
--num_epochs 1
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--learning_rate 1e-3
--lr_warmup_steps 5
""".split()
run_command(self._launch_args + test_args, return_stdout=True)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_unconditional_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
initial_run_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name hf-internal-testing/dummy_image_class_data
--model_config_name_or_path diffusers/ddpm_dummy
--resolution 64
--output_dir {tmpdir}
--train_batch_size 1
--num_epochs 1
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--learning_rate 1e-3
--lr_warmup_steps 5
--checkpointing_steps=2
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + initial_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)
def test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
initial_run_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name hf-internal-testing/dummy_image_class_data
--model_config_name_or_path diffusers/ddpm_dummy
--resolution 64
--output_dir {tmpdir}
--train_batch_size 1
--num_epochs 1
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--learning_rate 1e-3
--lr_warmup_steps 5
--checkpointing_steps=1
""".split()
run_command(self._launch_args + initial_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-1", "checkpoint-2", "checkpoint-3", "checkpoint-4", "checkpoint-5", "checkpoint-6"},
)
resume_run_args = f"""
examples/unconditional_image_generation/train_unconditional.py
--dataset_name hf-internal-testing/dummy_image_class_data
--model_config_name_or_path diffusers/ddpm_dummy
--resolution 64
--output_dir {tmpdir}
--train_batch_size 1
--num_epochs 2
--gradient_accumulation_steps 1
--ddpm_num_inference_steps 2
--learning_rate 1e-3
--lr_warmup_steps 5
--resume_from_checkpoint=checkpoint-6
--checkpointing_steps=2
--checkpoints_total_limit=3
""".split()
run_command(self._launch_args + resume_run_args)
# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
)
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -5,3 +5,4 @@ wandb
huggingface-cli
bitsandbytes
deepspeed
peft>=0.6.0
@@ -31,14 +31,14 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from modeling_efficient_net_encoder import EfficientNetEncoder
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm import tqdm
from transformers import CLIPTextModel, PreTrainedTokenizerFast
from transformers.utils import ContextManagers
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -139,17 +139,17 @@ More information on all the CLI arguments and the environment are available on y
f.write(yaml + model_card)
def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch):
def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
logger.info("Running validation... ")
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior=accelerator.unwrap_model(prior),
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.prior_prior.set_attn_processor(attn_processors)
pipeline.set_progress_bar_config(disable=True)
if args.seed is None:
@@ -159,7 +159,7 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
images = []
for i in range(len(args.validation_prompts)):
with torch.autocast("cuda"):
with torch.cuda.amp.autocast():
image = pipeline(
args.validation_prompts[i],
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
@@ -167,7 +167,6 @@ def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator,
height=args.resolution,
width=args.resolution,
).images[0]
images.append(image)
for tracker in accelerator.trackers:
@@ -527,11 +526,50 @@ def main():
prior.to(accelerator.device, dtype=weight_dtype)
# lora attn processor
lora_attn_procs = {}
for name in prior.attn_processors.keys():
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank)
prior.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(prior.attn_processors)
prior_lora_config = LoraConfig(
r=args.rank, target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"]
)
prior.add_adapter(prior_lora_config)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
prior_lora_layers_to_save = None
for model in models:
if isinstance(model, type(accelerator.unwrap_model(prior))):
prior_lora_layers_to_save = get_peft_model_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
WuerstchenPriorPipeline.save_lora_weights(
output_dir,
unet_lora_layers=prior_lora_layers_to_save,
)
def load_model_hook(models, input_dir):
prior_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(accelerator.unwrap_model(prior))):
prior_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alphas = WuerstchenPriorPipeline.lora_state_dict(input_dir)
WuerstchenPriorPipeline.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=prior_)
WuerstchenPriorPipeline.load_lora_into_text_encoder(
lora_state_dict,
network_alphas=network_alphas,
)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
@@ -547,8 +585,9 @@ def main():
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
params_to_optimize = list(filter(lambda p: p.requires_grad, prior.parameters()))
optimizer = optimizer_cls(
lora_layers.parameters(),
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -674,8 +713,8 @@ def main():
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
)
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
prior, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -782,7 +821,7 @@ def main():
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -828,17 +867,19 @@ def main():
if accelerator.is_main_process:
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
log_validation(
text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step
)
log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
prior = accelerator.unwrap_model(prior)
prior = prior.to(torch.float32)
prior_lora_state_dict = get_peft_model_state_dict(prior)
WuerstchenPriorPipeline.save_lora_weights(
os.path.join(args.output_dir, "prior_lora"),
unet_lora_layers=lora_layers,
save_directory=args.output_dir,
unet_lora_layers=prior_lora_state_dict,
)
# Run a final round of inference.
@@ -849,11 +890,12 @@ def main():
args.pretrained_decoder_model_name_or_path,
prior_text_encoder=accelerator.unwrap_model(text_encoder),
prior_tokenizer=tokenizer,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
# load lora weights
pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora"))
pipeline = pipeline.to(accelerator.device)
# load lora weights
pipeline.prior_pipe.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
pipeline.set_progress_bar_config(disable=True)
if args.seed is None:
@@ -862,7 +904,7 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
for i in range(len(args.validation_prompts)):
with torch.autocast("cuda"):
with torch.cuda.amp.autocast():
image = pipeline(
args.validation_prompts[i],
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.24.0.dev0")
check_min_version("0.25.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+98
View File
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
import argparse
import fnmatch
from safetensors.torch import load_file
from diffusers import Kandinsky3UNet
MAPPING = {
"to_time_embed.1": "time_embedding.linear_1",
"to_time_embed.3": "time_embedding.linear_2",
"in_layer": "conv_in",
"out_layer.0": "conv_norm_out",
"out_layer.2": "conv_out",
"down_samples": "down_blocks",
"up_samples": "up_blocks",
"projection_lin": "encoder_hid_proj.projection_linear",
"projection_ln": "encoder_hid_proj.projection_norm",
"feature_pooling": "add_time_condition",
"to_query": "to_q",
"to_key": "to_k",
"to_value": "to_v",
"output_layer": "to_out.0",
"self_attention_block": "attentions.0",
}
DYNAMIC_MAP = {
"resnet_attn_blocks.*.0": "resnets_in.*",
"resnet_attn_blocks.*.1": ("attentions.*", 1),
"resnet_attn_blocks.*.2": "resnets_out.*",
}
# MAPPING = {}
def convert_state_dict(unet_state_dict):
"""
Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
Args:
unet_model (torch.nn.Module): The original U-Net model.
unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
Returns:
OrderedDict: The converted state dictionary.
"""
# Example of renaming logic (this will vary based on your model's architecture)
converted_state_dict = {}
for key in unet_state_dict:
new_key = key
for pattern, new_pattern in MAPPING.items():
new_key = new_key.replace(pattern, new_pattern)
for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
has_matched = False
if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
if isinstance(dyn_new_pattern, tuple):
new_star = star + dyn_new_pattern[-1]
dyn_new_pattern = dyn_new_pattern[0]
else:
new_star = star
pattern = dyn_pattern.replace("*", str(star))
new_pattern = dyn_new_pattern.replace("*", str(new_star))
new_key = new_key.replace(pattern, new_pattern)
has_matched = True
converted_state_dict[new_key] = unet_state_dict[key]
return converted_state_dict
def main(model_path, output_path):
# Load your original U-Net model
unet_state_dict = load_file(model_path)
# Initialize your Kandinsky3UNet model
config = {}
# Convert the state dict
converted_state_dict = convert_state_dict(unet_state_dict)
unet = Kandinsky3UNet(config)
unet.load_state_dict(converted_state_dict)
unet.save_pretrained(output_path)
print(f"Converted model saved to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
args = parser.parse_args()
main(args.model_path, args.output_path)
+730
View File
@@ -0,0 +1,730 @@
from diffusers.utils import is_accelerate_available, logging
if is_accelerate_available():
pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
else:
unet_params = original_config.model.params.network_config.params
vae_params = original_config.model.params.first_stage_config.params.encoder_config.params
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = (
"CrossAttnDownBlockSpatioTemporal"
if resolution in unet_params.attention_resolutions
else "DownBlockSpatioTemporal"
)
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = (
"CrossAttnUpBlockSpatioTemporal"
if resolution in unet_params.attention_resolutions
else "UpBlockSpatioTemporal"
)
up_block_types.append(block_type)
resolution //= 2
if unet_params.transformer_depth is not None:
transformer_layers_per_block = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
)
else:
transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
class_embed_type = None
addition_embed_type = None
addition_time_embed_dim = None
projection_class_embeddings_input_dim = None
context_dim = None
if unet_params.context_dim is not None:
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
)
if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
addition_time_embed_dim = 256
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"addition_embed_type": addition_embed_type,
"addition_time_embed_dim": addition_time_embed_dim,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"transformer_layers_per_block": transformer_layers_per_block,
}
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
config["out_channels"] = unet_params.out_channels
config["up_block_types"] = tuple(up_block_types)
return config
def assign_to_checkpoint(
paths,
checkpoint,
old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None,
config=None,
mid_block_suffix="",
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
if mid_block_suffix is not None:
mid_block_suffix = f".{mid_block_suffix}"
else:
mid_block_suffix = ""
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight":
print("yeyy")
# proj_attn.weight has to be converted from conv 1D to linear
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
shape = old_checkpoint[path["old"]].shape
if is_attn_weight and len(shape) == 3:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif is_attn_weight and len(shape) == 4:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
new_item = new_item.replace("time_stack", "temporal_transformer_blocks")
new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias")
new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight")
new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias")
new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight")
mapping.append({"old": old_item, "new": new_item})
return mapping
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = new_item.replace("time_stack.", "")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_unet_checkpoint(
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
if skip_extract_state_dict:
unet_state_dict = checkpoint
else:
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
logger.warning(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
if config["class_embed_type"] is None:
# No parameters to port
...
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
else:
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
# if config["addition_embed_type"] == "text_time":
new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
spatial_resnets = [
key
for key in input_blocks[i]
if f"input_blocks.{i}.0" in key
and (
f"input_blocks.{i}.0.op" not in key
and f"input_blocks.{i}.0.time_stack" not in key
and f"input_blocks.{i}.0.time_mixer" not in key
)
]
temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key]
# import ipdb; ipdb.set_trace()
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(spatial_resnets)
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
paths = renew_resnet_paths(temporal_resnets)
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
# import ipdb; ipdb.set_trace()
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key]
resnet_0_paths = renew_resnet_paths(resnet_0_spatial)
# import ipdb; ipdb.set_trace()
assign_to_checkpoint(
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
)
resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key]
resnet_0_paths = renew_resnet_paths(resnet_0_temporal)
assign_to_checkpoint(
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
)
resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key]
resnet_1_paths = renew_resnet_paths(resnet_1_spatial)
assign_to_checkpoint(
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
)
resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key]
resnet_1_paths = renew_resnet_paths(resnet_1_temporal)
assign_to_checkpoint(
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
)
new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
"middle_block.0.time_mixer.mix_factor"
]
new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[
"middle_block.2.time_mixer.mix_factor"
]
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
spatial_resnets = [
key
for key in output_blocks[i]
if f"output_blocks.{i}.0" in key
and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key)
]
# import ipdb; ipdb.set_trace()
temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key]
paths = renew_resnet_paths(spatial_resnets)
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
paths = renew_resnet_paths(temporal_resnets)
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
if len(attentions):
paths = renew_attention_paths(attentions)
# import ipdb; ipdb.set_trace()
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
spatial_layers = [
layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
]
resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
# import ipdb; ipdb.set_trace()
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(
["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
temporal_layers = [
layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in key
]
resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
# import ipdb; ipdb.set_trace()
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(
["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
]
return new_checkpoint
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# Temporal resnet
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = new_item.replace("time_stack.", "temporal_res_block.")
# Spatial resnet
new_item = new_item.replace("conv1", "spatial_res_block.conv1")
new_item = new_item.replace("norm1", "spatial_res_block.norm1")
new_item = new_item.replace("conv2", "spatial_res_block.conv2")
new_item = new_item.replace("norm2", "spatial_res_block.norm2")
new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut")
new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
keys = list(checkpoint.keys())
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"]
new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"]
# new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
# new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
# new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
# new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
+6 -25
View File
@@ -118,9 +118,10 @@ _deps = [
"pytest-timeout",
"pytest-xdist",
"python>=3.8.0",
"ruff>=0.1.5,<=0.2",
"ruff==0.1.5",
"safetensors>=0.3.1",
"sentencepiece>=0.1.91,!=0.1.92",
"GitPython<3.1.19",
"scipy",
"onnx",
"regex!=2019.12.17",
@@ -206,6 +207,7 @@ extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2")
extras["test"] = deps_list(
"compel",
"GitPython",
"datasets",
"Jinja2",
"invisible-watermark",
@@ -249,13 +251,13 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.24.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.25.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
keywords="deep learning diffusion jax pytorch stable diffusion audioldm",
license="Apache",
author="The HuggingFace team",
license="Apache 2.0 License",
author="The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)",
author_email="patrick@huggingface.co",
url="https://github.com/huggingface/diffusers",
package_dir={"": "src"},
@@ -279,24 +281,3 @@ setup(
+ [f"Programming Language :: Python :: 3.{i}" for i in range(8, version_range_max)],
cmdclass={"deps_table_update": DepsTableUpdateCommand},
)
# Release checklist
# 1. Change the version in __init__.py and setup.py.
# 2. Commit these changes with the message: "Release: Release"
# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for PyPI'"
# Push the tag to git: git push --tags origin main
# 4. Run the following commands in the top-level directory:
# python setup.py bdist_wheel
# python setup.py sdist
# 5. Upload the package to the PyPI test server first:
# twine upload dist/* -r pypitest
# twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
# 6. Check that you can install it in a virtualenv by running:
# pip install -i https://testpypi.python.org/pypi diffusers
# diffusers env
# diffusers test
# 7. Upload the final version to the actual PyPI:
# twine upload dist/* -r pypi
# 8. Add release notes to the tag in GitHub once everything is looking hunky-dory.
# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to main.
+15 -1
View File
@@ -1,4 +1,4 @@
__version__ = "0.24.0.dev0"
__version__ = "0.25.0.dev0"
from typing import TYPE_CHECKING
@@ -76,9 +76,11 @@ else:
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
"Kandinsky3UNet",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -91,6 +93,7 @@ else:
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"VQModel",
]
)
@@ -214,6 +217,8 @@ else:
"IFPipeline",
"IFSuperResolutionPipeline",
"ImageTextPipelineOutput",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
@@ -274,8 +279,10 @@ else:
"StableDiffusionXLPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
"StableVideoDiffusionPipeline",
"TextToVideoSDPipeline",
"TextToVideoZeroPipeline",
"TextToVideoZeroSDXLPipeline",
"UnCLIPImageVariationPipeline",
"UnCLIPPipeline",
"UniDiffuserModel",
@@ -443,9 +450,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .models import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
Kandinsky3UNet,
ModelMixin,
MotionAdapter,
MultiAdapter,
@@ -458,6 +467,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
VQModel,
)
from .optimization import (
@@ -560,6 +570,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
IFPipeline,
IFSuperResolutionPipeline,
ImageTextPipelineOutput,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
@@ -620,8 +632,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableVideoDiffusionPipeline,
TextToVideoSDPipeline,
TextToVideoZeroPipeline,
TextToVideoZeroSDXLPipeline,
UnCLIPImageVariationPipeline,
UnCLIPPipeline,
UniDiffuserModel,
+2 -1
View File
@@ -30,9 +30,10 @@ deps = {
"pytest-timeout": "pytest-timeout",
"pytest-xdist": "pytest-xdist",
"python": "python>=3.8.0",
"ruff": "ruff>=0.1.5,<=0.2",
"ruff": "ruff==0.1.5",
"safetensors": "safetensors>=0.3.1",
"sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"regex": "regex!=2019.12.17",
@@ -113,7 +113,7 @@ class ValueGuidedRLPipeline(DiffusionPipeline):
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
# TODO: verify deprecation of this kwarg
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
x = self.scheduler.step(prev_x, i, x)["prev_sample"]
# apply conditions to the trajectory (set the initial state)
x = self.reset_x0(x, conditions, self.action_dim)
+121 -1
View File
@@ -33,6 +33,15 @@ PipelineImageInput = Union[
List[torch.FloatTensor],
]
PipelineDepthInput = Union[
PIL.Image.Image,
np.ndarray,
torch.FloatTensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.FloatTensor],
]
class VaeImageProcessor(ConfigMixin):
"""
@@ -326,7 +335,7 @@ class VaeImageProcessor(ConfigMixin):
# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
if image.min() < 0 and do_normalize:
if do_normalize and image.min() < 0:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
@@ -441,6 +450,18 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
return pil_images
@staticmethod
def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
"""
Convert a PIL image or a list of PIL images to NumPy arrays.
"""
if not isinstance(images, list):
images = [images]
images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
images = np.stack(images, axis=0)
return images
@staticmethod
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
"""
@@ -526,3 +547,102 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
return self.numpy_to_pil(image), self.numpy_to_depth(image)
else:
raise Exception(f"This type {output_type} is not supported")
def preprocess(
self,
rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
height: Optional[int] = None,
width: Optional[int] = None,
target_res: Optional[int] = None,
) -> torch.Tensor:
"""
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
# Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
raise Exception("This is not yet supported")
if isinstance(rgb, supported_formats):
rgb = [rgb]
depth = [depth]
elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(supported_formats)}"
)
if isinstance(rgb[0], PIL.Image.Image):
if self.config.do_convert_rgb:
raise Exception("This is not yet supported")
# rgb = [self.convert_to_rgb(i) for i in rgb]
# depth = [self.convert_to_depth(i) for i in depth] #TODO define convert_to_depth
if self.config.do_resize or target_res:
height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
rgb = [self.resize(i, height, width) for i in rgb]
depth = [self.resize(i, height, width) for i in depth]
rgb = self.pil_to_numpy(rgb) # to np
rgb = self.numpy_to_pt(rgb) # to pt
depth = self.depth_pil_to_numpy(depth) # to np
depth = self.numpy_to_pt(depth) # to pt
elif isinstance(rgb[0], np.ndarray):
rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
rgb = self.numpy_to_pt(rgb)
height, width = self.get_default_height_width(rgb, height, width)
if self.config.do_resize:
rgb = self.resize(rgb, height, width)
depth = np.concatenate(depth, axis=0) if rgb[0].ndim == 4 else np.stack(depth, axis=0)
depth = self.numpy_to_pt(depth)
height, width = self.get_default_height_width(depth, height, width)
if self.config.do_resize:
depth = self.resize(depth, height, width)
elif isinstance(rgb[0], torch.Tensor):
raise Exception("This is not yet supported")
# rgb = torch.cat(rgb, axis=0) if rgb[0].ndim == 4 else torch.stack(rgb, axis=0)
# if self.config.do_convert_grayscale and rgb.ndim == 3:
# rgb = rgb.unsqueeze(1)
# channel = rgb.shape[1]
# height, width = self.get_default_height_width(rgb, height, width)
# if self.config.do_resize:
# rgb = self.resize(rgb, height, width)
# depth = torch.cat(depth, axis=0) if depth[0].ndim == 4 else torch.stack(depth, axis=0)
# if self.config.do_convert_grayscale and depth.ndim == 3:
# depth = depth.unsqueeze(1)
# channel = depth.shape[1]
# # don't need any preprocess if the image is latents
# if depth == 4:
# return rgb, depth
# height, width = self.get_default_height_width(depth, height, width)
# if self.config.do_resize:
# depth = self.resize(depth, height, width)
# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
if rgb.min() < 0 and do_normalize:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
FutureWarning,
)
do_normalize = False
if do_normalize:
rgb = self.normalize(rgb)
depth = self.normalize(depth)
if self.config.do_binarize:
rgb = self.binarize(rgb)
depth = self.binarize(depth)
return rgb, depth
+4 -3
View File
@@ -8,7 +8,7 @@ def text_encoder_lora_state_dict(text_encoder):
deprecate(
"text_encoder_load_state_dict in `models`",
"0.27.0",
"`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
"`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
)
state_dict = {}
@@ -34,7 +34,7 @@ if is_transformers_available():
deprecate(
"text_encoder_attn_modules in `models`",
"0.27.0",
"`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
"`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
)
from transformers import CLIPTextModel, CLIPTextModelWithProjection
@@ -62,16 +62,17 @@ if is_torch_available():
_import_structure["single_file"].extend(["FromSingleFileMixin"])
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from ..models.lora import text_encoder_lora_state_dict
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import IPAdapterMixin
from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
+157
View File
@@ -0,0 +1,157 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Union
import torch
from safetensors import safe_open
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
is_transformers_available,
logging,
)
if is_transformers_available():
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
)
from ..models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
logger = logging.get_logger(__name__)
class IPAdapterMixin:
"""Mixin for handling IP Adapters."""
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
subfolder: str,
weight_name: str,
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
# load CLIP image encoer here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=os.path.join(subfolder, "image_encoder"),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
self.feature_extractor = CLIPImageProcessor()
# load ip-adapter into unet
self.unet._load_ip_adapter_weights(state_dict)
def set_ip_adapter_scale(self, scale):
for attn_processor in self.unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = scale
+96 -528
View File
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
@@ -44,6 +43,7 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
if is_transformers_available():
@@ -68,7 +68,8 @@ LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This w
class LoraLoaderMixin:
r"""
Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`].
Load LoRA layers into [`UNet2DConditionModel`] and
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
"""
text_encoder_name = TEXT_ENCODER_NAME
@@ -94,28 +95,12 @@ class LoraLoaderMixin:
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
A string (model id of a pretrained model hosted on the Hub), a path to a directory containing the model
weights, or a [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is
the total number of adapters being loaded. Must have PEFT installed to use.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
"cuda"
)
pipeline.load_lora_weights(
"Yntec/pineappleAnimeMix", weight_name="pineappleAnimeMix_pineapple10.1.safetensors", adapter_name="anime"
)
```
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
@@ -153,7 +138,15 @@ class LoraLoaderMixin:
**kwargs,
):
r"""
Return state dict and network alphas of the LoRA weights.
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -161,7 +154,8 @@ class LoraLoaderMixin:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
@@ -197,6 +191,7 @@ class LoraLoaderMixin:
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
@@ -293,8 +288,8 @@ class LoraLoaderMixin:
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
return state_dict, network_alphas
@@ -334,109 +329,6 @@ class LoraLoaderMixin:
weight_name = targeted_files[0]
return weight_name
@classmethod
def _maybe_map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5):
# 1. get all state_dict_keys
all_keys = list(state_dict.keys())
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
# 2. check if needs remapping, if not return original dict
is_in_sgm_format = False
for key in all_keys:
if any(p in key for p in sgm_patterns):
is_in_sgm_format = True
break
if not is_in_sgm_format:
return state_dict
# 3. Else remap from SGM patterns
new_state_dict = {}
inner_block_map = ["resnets", "attentions", "upsamplers"]
# Retrieves # of down, mid and up blocks
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
for layer in all_keys:
if "text" in layer:
new_state_dict[layer] = state_dict.pop(layer)
else:
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
if sgm_patterns[0] in layer:
input_block_ids.add(layer_id)
elif sgm_patterns[1] in layer:
middle_block_ids.add(layer_id)
elif sgm_patterns[2] in layer:
output_block_ids.add(layer_id)
else:
raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
input_blocks = {
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
for layer_id in input_block_ids
}
middle_blocks = {
layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
for layer_id in middle_block_ids
}
output_blocks = {
layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
for layer_id in output_block_ids
}
# Rename keys accordingly
for i in input_block_ids:
block_id = (i - 1) // (unet_config.layers_per_block + 1)
layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
for key in input_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in middle_block_ids:
key_part = None
if i == 0:
key_part = [inner_block_map[0], "0"]
elif i == 1:
key_part = [inner_block_map[1], "0"]
elif i == 2:
key_part = [inner_block_map[0], "1"]
else:
raise ValueError(f"Invalid middle block id {i}.")
for key in middle_blocks[i]:
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in output_block_ids:
block_id = i // (unet_config.layers_per_block + 1)
layer_in_block_id = i % (unet_config.layers_per_block + 1)
for key in output_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id]
inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
if len(state_dict) > 0:
raise ValueError("At this point all state dict entries have to be converted.")
return new_state_dict
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
"""
@@ -473,27 +365,25 @@ class LoraLoaderMixin:
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
"""
Load LoRA layers specified in `state_dict` into `unet`.
This will load the LoRA layers specified in `state_dict` into `unet`.
Parameters:
state_dict (`dict`):
A standard state dict containing the LoRA layer parameters. The keys can either be indexed directly
into the `unet` or prefixed with an additional `unet`, which can be used to distinguish between text
encoder LoRA layers.
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`Dict[str, float]`):
See
[`LoRALinearLayer`](https://github.com/huggingface/diffusers/blob/c697f524761abd2314c030221a3ad2f7791eab4e/src/diffusers/models/lora.py#L182)
for more details.
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Only load and not initialize the pretrained weights. This can speedup model loading and also tries to
not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only
supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to
`True` will raise an error.
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is
the total number of adapters being loaded.
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -501,6 +391,10 @@ class LoraLoaderMixin:
# their prefixes.
keys = list(state_dict.keys())
if all(key.startswith("unet.unet") for key in keys):
deprecation_message = "Keys starting with 'unet.unet' are deprecated."
deprecate("unet.unet keys", "0.27", deprecation_message)
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
@@ -517,8 +411,9 @@ class LoraLoaderMixin:
else:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warn(warn_message)
if not USE_PEFT_BACKEND:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warn(warn_message)
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
@@ -587,27 +482,26 @@ class LoraLoaderMixin:
_pipeline=None,
):
"""
Load LoRA layers specified in `state_dict` into `text_encoder`.
This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
A standard state dict containing the LoRA layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between UNet LoRA layers.
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alphas (`Dict[str, float]`):
See
[`LoRALinearLayer`](https://github.com/huggingface/diffusers/blob/c697f524761abd2314c030221a3ad2f7791eab4e/src/diffusers/models/lora.py#L182)
for more details.
See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
lora_scale (`float`):
Scale of `LoRALinearLayer`'s output before it is added with the output of the regular LoRA layer.
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Only load and not initialize the pretrained weights. This can speedup model loading and also tries to
not use more than 1x model size in CPU memory (including peak memory) while loading the model. Only
supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this argument to
`True` will raise an error.
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
@@ -786,8 +680,7 @@ class LoraLoaderMixin:
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
if version.parse(__version__) > version.parse("0.23"):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.27", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
@@ -815,8 +708,7 @@ class LoraLoaderMixin:
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
if version.parse(__version__) > version.parse("0.23"):
deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
deprecate("_modify_text_encoder", "0.27", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -892,11 +784,11 @@ class LoraLoaderMixin:
safe_serialization: bool = True,
):
r"""
Save the UNet and text encoder LoRA parameters.
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to (will be created if it doesn't exist).
Directory to save LoRA parameters to. Will be created if it doesn't exist.
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `unet`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
@@ -907,54 +799,27 @@ class LoraLoaderMixin:
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dict. Useful during distributed training when you need to replace
`torch.save` with another method. Can be configured with the environment variable
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or with `pickle`.
Example:
```py
from diffusers import StableDiffusionXLPipeline
from peft.utils import get_peft_model_state_dict
import torch
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora()
# get and save unet state dict
unet_state_dict = get_peft_model_state_dict(pipeline.unet, adapter_name="pixel")
pipeline.save_lora_weights("fused-model", unet_lora_layers=unet_state_dict)
pipeline.load_lora_weights("fused-model", weight_name="pytorch_lora_weights.safetensors")
```
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
# Create a flat dictionary.
state_dict = {}
# Populate the dictionary.
if unet_lora_layers is not None:
weights = (
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
)
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
state_dict.update(unet_lora_state_dict)
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.")
if text_encoder_lora_layers is not None:
weights = (
text_encoder_lora_layers.state_dict()
if isinstance(text_encoder_lora_layers, torch.nn.Module)
else text_encoder_lora_layers
)
if unet_lora_layers:
state_dict.update(pack_weights(unet_lora_layers, "unet"))
text_encoder_lora_state_dict = {
f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
}
state_dict.update(text_encoder_lora_state_dict)
if text_encoder_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
# Save the model
cls.write_lora_layers(
@@ -999,186 +864,16 @@ class LoraLoaderMixin:
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
@classmethod
def _convert_kohya_lora_to_diffusers(cls, state_dict):
unet_state_dict = {}
te_state_dict = {}
te2_state_dict = {}
network_alphas = {}
# every down weight has a corresponding up weight and potentially an alpha weight
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
for key in lora_keys:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
if "input.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
else:
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
if "middle.block" in diffusers_name:
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
else:
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
if "output.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
else:
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
# SDXL specificity.
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
if ".out." in diffusers_name:
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
diffusers_name = diffusers_name.replace("op", "conv")
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
# LyCORIS specificity.
if "time.emb.proj" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
# General coverage.
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te1_"):
diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te2_"):
diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Rename the alphas so that they can be mapped appropriately.
if lora_name_alpha in state_dict:
alpha = state_dict.pop(lora_name_alpha).item()
if lora_name_alpha.startswith("lora_unet_"):
prefix = "unet."
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
prefix = "text_encoder."
else:
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
network_alphas.update({new_name: alpha})
if len(state_dict) > 0:
raise ValueError(
f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}"
)
logger.info("Kohya-style checkpoint detected.")
unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {
f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()
}
te2_state_dict = (
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
if len(te2_state_dict) > 0
else None
)
if te2_state_dict is not None:
te_state_dict.update(te2_state_dict)
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alphas
def unload_lora_weights(self):
"""
Unload the LoRA parameters from a pipeline.
Unloads the LoRA parameters.
Examples:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.unload_lora_weights()
```python
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
>>> pipeline.unload_lora_weights()
>>> ...
```
"""
if not USE_PEFT_BACKEND:
@@ -1207,7 +902,7 @@ class LoraLoaderMixin:
safe_fusing: bool = False,
):
r"""
Fuse the LoRA parameters with the original parameters in their corresponding blocks.
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
@@ -1221,23 +916,9 @@ class LoraLoaderMixin:
Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
lora_scale (`float`, defaults to 1.0):
Controls LoRA influence on the outputs.
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for `NaN` values before fusing and if values are `NaN`, then don't fuse
them.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
"""
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
@@ -1262,8 +943,7 @@ class LoraLoaderMixin:
module.merge()
else:
if version.parse(__version__) > version.parse("0.23"):
deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
deprecate("fuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
for _, attn_module in text_encoder_attn_modules(text_encoder):
@@ -1286,7 +966,8 @@ class LoraLoaderMixin:
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
Unfuse the LoRA parameters from the original parameters in their corresponding blocks.
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.fuse_lora).
<Tip warning={true}>
@@ -1299,20 +980,6 @@ class LoraLoaderMixin:
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
pipeline.unfuse_lora()
```
"""
if unfuse_unet:
if not USE_PEFT_BACKEND:
@@ -1333,8 +1000,7 @@ class LoraLoaderMixin:
module.unmerge()
else:
if version.parse(__version__) > version.parse("0.23"):
deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
deprecate("unfuse_text_encoder_lora", "0.27", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
@@ -1364,32 +1030,16 @@ class LoraLoaderMixin:
text_encoder_weights: List[float] = None,
):
"""
Set the currently active adapter for use in the text encoder.
Sets the adapter layers for the text encoder.
Args:
adapter_names (`List[str]` or `str`):
The adapter to activate.
The names of the adapters to use.
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to activate the adapter layers for. If `None`, it will try to get the
`text_encoder` attribute.
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
attribute.
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.set_adapters_for_text_encoder("pixel")
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1417,25 +1067,12 @@ class LoraLoaderMixin:
def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
"""
Disable the text encoder's LoRA layers.
Disables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
`text_encoder` attribute.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.disable_lora_for_text_encoder()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1447,25 +1084,12 @@ class LoraLoaderMixin:
def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
"""
Enables the text encoder's LoRA layers.
Enables the LoRA layers for the text encoder.
Args:
text_encoder (`torch.nn.Module`, *optional*):
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.enable_lora_for_text_encoder()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1516,24 +1140,10 @@ class LoraLoaderMixin:
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Delete an adapter's LoRA layers from the UNet and text encoder(s).
Args:
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
adapter_names (`Union[List[str], str]`):
The names (single string or list of strings) of the adapter to delete.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.delete_adapters("pixel")
```
The names of the adapter to delete. Can be a single string or a list of strings
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1553,7 +1163,7 @@ class LoraLoaderMixin:
def get_active_adapters(self) -> List[str]:
"""
Get a list of currently active adapters.
Gets the list of the current active adapters.
Example:
@@ -1585,22 +1195,7 @@ class LoraLoaderMixin:
def get_list_adapters(self) -> Dict[str, List[str]]:
"""
Get a list of all currently available adapters for each component in the pipeline.
Example:
```py
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.get_list_adapters()
```
Gets the current list of all available adapters in the pipeline.
"""
if not USE_PEFT_BACKEND:
raise ValueError(
@@ -1622,27 +1217,14 @@ class LoraLoaderMixin:
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
"""
Move a LoRA to a target device. Useful for offloading a LoRA to the CPU in case you want to load multiple
adapters and free some GPU memory.
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
Args:
adapter_names (`List[str]`):
List of adapters to send to device.
List of adapters to send device to.
device (`Union[torch.device, str, int]`):
Device (can be a `torch.device`, `str` or `int`) to place adapters on.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.set_lora_device(["pixel"], device="cuda")
```
Device to send the adapters to. Can be either a torch device, a str or an integer.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1674,7 +1256,7 @@ class LoraLoaderMixin:
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
"""This class overrides [`LoraLoaderMixin`] with LoRA loading/saving code that's specific to SDXL."""
"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
# Overrride to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(
@@ -1699,26 +1281,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
A string (model id of a pretrained model hosted on the Hub), a path to a directory containing the model
weights, or a [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Name for referencing the loaded adapter model. If not specified, it will use `default_{i}` where `i` is
the total number of adapters being loaded. Must have PEFT installed to use.
Example:
```py
from diffusers import StableDiffusionXLPipeline
import torch
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
```
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
@@ -0,0 +1,284 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from ..utils import logging
logger = logging.get_logger(__name__)
def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", block_slice_pos=5):
# 1. get all state_dict_keys
all_keys = list(state_dict.keys())
sgm_patterns = ["input_blocks", "middle_block", "output_blocks"]
# 2. check if needs remapping, if not return original dict
is_in_sgm_format = False
for key in all_keys:
if any(p in key for p in sgm_patterns):
is_in_sgm_format = True
break
if not is_in_sgm_format:
return state_dict
# 3. Else remap from SGM patterns
new_state_dict = {}
inner_block_map = ["resnets", "attentions", "upsamplers"]
# Retrieves # of down, mid and up blocks
input_block_ids, middle_block_ids, output_block_ids = set(), set(), set()
for layer in all_keys:
if "text" in layer:
new_state_dict[layer] = state_dict.pop(layer)
else:
layer_id = int(layer.split(delimiter)[:block_slice_pos][-1])
if sgm_patterns[0] in layer:
input_block_ids.add(layer_id)
elif sgm_patterns[1] in layer:
middle_block_ids.add(layer_id)
elif sgm_patterns[2] in layer:
output_block_ids.add(layer_id)
else:
raise ValueError(f"Checkpoint not supported because layer {layer} not supported.")
input_blocks = {
layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key]
for layer_id in input_block_ids
}
middle_blocks = {
layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key]
for layer_id in middle_block_ids
}
output_blocks = {
layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key]
for layer_id in output_block_ids
}
# Rename keys accordingly
for i in input_block_ids:
block_id = (i - 1) // (unet_config.layers_per_block + 1)
layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1)
for key in input_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers"
inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in middle_block_ids:
key_part = None
if i == 0:
key_part = [inner_block_map[0], "0"]
elif i == 1:
key_part = [inner_block_map[1], "0"]
elif i == 2:
key_part = [inner_block_map[0], "1"]
else:
raise ValueError(f"Invalid middle block id {i}.")
for key in middle_blocks[i]:
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:]
)
new_state_dict[new_key] = state_dict.pop(key)
for i in output_block_ids:
block_id = i // (unet_config.layers_per_block + 1)
layer_in_block_id = i % (unet_config.layers_per_block + 1)
for key in output_blocks[i]:
inner_block_id = int(key.split(delimiter)[block_slice_pos])
inner_block_key = inner_block_map[inner_block_id]
inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0"
new_key = delimiter.join(
key.split(delimiter)[: block_slice_pos - 1]
+ [str(block_id), inner_block_key, inner_layers_in_block]
+ key.split(delimiter)[block_slice_pos + 1 :]
)
new_state_dict[new_key] = state_dict.pop(key)
if len(state_dict) > 0:
raise ValueError("At this point all state dict entries have to be converted.")
return new_state_dict
def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
unet_state_dict = {}
te_state_dict = {}
te2_state_dict = {}
network_alphas = {}
# every down weight has a corresponding up weight and potentially an alpha weight
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
for key in lora_keys:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha"
if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
if "input.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
else:
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
if "middle.block" in diffusers_name:
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
else:
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
if "output.blocks" in diffusers_name:
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
else:
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
# SDXL specificity.
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
if ".out." in diffusers_name:
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
diffusers_name = diffusers_name.replace("op", "conv")
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
# LyCORIS specificity.
if "time.emb.proj" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
# General coverage.
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te1_"):
diffusers_name = key.replace("lora_te1_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# (sayakpaul): Duplicate code. Needs to be cleaned.
elif lora_name.startswith("lora_te2_"):
diffusers_name = key.replace("lora_te2_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
# Rename the alphas so that they can be mapped appropriately.
if lora_name_alpha in state_dict:
alpha = state_dict.pop(lora_name_alpha).item()
if lora_name_alpha.startswith("lora_unet_"):
prefix = "unet."
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
prefix = "text_encoder."
else:
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
network_alphas.update({new_name: alpha})
if len(state_dict) > 0:
raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Kohya-style checkpoint detected.")
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
te2_state_dict = (
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
if len(te2_state_dict) > 0
else None
)
if te2_state_dict is not None:
te_state_dict.update(te2_state_dict)
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alphas
+14 -2
View File
@@ -189,7 +189,7 @@ class TextualInversionLoaderMixin:
f" `{self.load_textual_inversion.__name__}`"
)
if len(pretrained_model_name_or_paths) != len(tokens):
if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens):
raise ValueError(
f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
f"Make sure both lists have the same length."
@@ -382,7 +382,9 @@ class TextualInversionLoaderMixin:
if not isinstance(pretrained_model_name_or_path, list)
else pretrained_model_name_or_path
)
tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
tokens = [token] if not isinstance(token, list) else token
if tokens[0] is None:
tokens = tokens * len(pretrained_model_name_or_paths)
# 3. Check inputs
self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
@@ -390,6 +392,16 @@ class TextualInversionLoaderMixin:
# 4. Load state dicts of textual embeddings
state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
# 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens
if len(tokens) > 1 and len(state_dicts) == 1:
if isinstance(state_dicts[0], torch.Tensor):
state_dicts = list(state_dicts[0])
if len(tokens) != len(state_dicts):
raise ValueError(
f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} "
f"Make sure both have the same length."
)
# 4. Retrieve tokens and embeddings
tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
+135 -1
View File
@@ -12,14 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from collections import OrderedDict, defaultdict
from contextlib import nullcontext
from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
import torch.nn.functional as F
from torch import nn
from ..models.embeddings import ImageProjection, Resampler
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
@@ -662,4 +664,136 @@ class UNet2DConditionLoadersMixin:
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
def _load_ip_adapter_weights(self, state_dict):
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
num_image_text_embeds = 4
else:
# IP-Adapter Plus
num_image_text_embeds = state_dict["image_proj"]["latents"].shape[1]
# Set encoder_hid_proj after loading ip_adapter weights,
# because `Resampler` also has `attn_processors`.
self.encoder_hid_proj = None
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
for name in self.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = self.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
).to(dtype=self.dtype, device=self.device)
value_dict = {}
for k, w in attn_procs[name].state_dict().items():
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
attn_procs[name].load_state_dict(value_dict)
key_id += 2
self.set_attn_processor(attn_procs)
# create image projection layers.
if "proj.weight" in state_dict["image_proj"]:
# IP-Adapter
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
image_projection.to(dtype=self.dtype, device=self.device)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
else:
# IP-Adapter Plus
embed_dims = state_dict["image_proj"]["proj_in.weight"].shape[1]
output_dims = state_dict["image_proj"]["proj_out.weight"].shape[0]
hidden_dims = state_dict["image_proj"]["latents"].shape[2]
heads = state_dict["image_proj"]["layers.0.0.to_q.weight"].shape[0] // 64
image_projection = Resampler(
embed_dims=embed_dims,
output_dims=output_dims,
hidden_dims=hidden_dims,
heads=heads,
num_queries=num_image_text_embeds,
)
image_proj_state_dict = state_dict["image_proj"]
new_sd = OrderedDict()
for k, v in image_proj_state_dict.items():
if "0.to" in k:
k = k.replace("0.to", "2.to")
elif "1.0.weight" in k:
k = k.replace("1.0.weight", "3.0.weight")
elif "1.0.bias" in k:
k = k.replace("1.0.bias", "3.0.bias")
elif "1.1.weight" in k:
k = k.replace("1.1.weight", "3.1.net.0.proj.weight")
elif "1.3.weight" in k:
k = k.replace("1.3.weight", "3.1.net.2.weight")
if "norm1" in k:
new_sd[k.replace("0.norm1", "0")] = v
elif "norm2" in k:
new_sd[k.replace("0.norm2", "1")] = v
elif "to_kv" in k:
v_chunk = v.chunk(2, dim=0)
new_sd[k.replace("to_kv", "to_k")] = v_chunk[0]
new_sd[k.replace("to_kv", "to_v")] = v_chunk[1]
elif "to_out" in k:
new_sd[k.replace("to_out", "to_out.0")] = v
else:
new_sd[k] = v
image_projection.load_state_dict(new_sd)
del image_proj_state_dict
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"
delete_adapter_layers
+14 -1
View File
@@ -14,7 +14,12 @@
from typing import TYPE_CHECKING
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
is_flax_available,
is_torch_available,
)
_import_structure = {}
@@ -23,11 +28,13 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
@@ -36,7 +43,9 @@ if is_torch_available():
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -50,10 +59,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
@@ -63,7 +74,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
if is_flax_available():
+9 -6
View File
@@ -55,11 +55,12 @@ class GELU(nn.Module):
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
self.approximate = approximate
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
@@ -81,13 +82,14 @@ class GEGLU(nn.Module):
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
self.proj = linear_cls(dim_in, dim_out * 2)
self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
@@ -109,11 +111,12 @@ class ApproximateGELU(nn.Module):
Parameters:
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
+172 -19
View File
@@ -25,6 +25,31 @@ from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero
def _chunked_feed_forward(
ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
if lora_scale is None:
ff_output = torch.cat(
[ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
else:
# TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
ff_output = torch.cat(
[ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
dim=chunk_dim,
)
return ff_output
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
@@ -194,7 +219,12 @@ class BasicTransformerBlock(nn.Module):
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
@@ -208,7 +238,7 @@ class BasicTransformerBlock(nn.Module):
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
@@ -311,18 +341,8 @@ class BasicTransformerBlock(nn.Module):
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[
self.ff(hid_slice, scale=lora_scale)
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
],
dim=self._chunk_dim,
ff_output = _chunked_feed_forward(
self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
@@ -339,6 +359,137 @@ class BasicTransformerBlock(nn.Module):
return hidden_states
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video like data.
Parameters:
dim (`int`): The number of channels in the input and output.
time_mix_inner_dim (`int`): The number of channels for temporal attention.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = None
def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
# Sets chunk feed-forward
self._chunk_size = chunk_size
# chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
self._chunk_dim = 1
def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
if self._chunk_size is not None:
hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
else:
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
@@ -350,6 +501,7 @@ class FeedForward(nn.Module):
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(
@@ -360,6 +512,7 @@ class FeedForward(nn.Module):
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
bias: bool = True,
):
super().__init__()
inner_dim = int(dim * mult)
@@ -367,13 +520,13 @@ class FeedForward(nn.Module):
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
act_fn = GELU(dim, inner_dim, bias=bias)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh")
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
act_fn = GEGLU(dim, inner_dim, bias=bias)
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim)
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
self.net = nn.ModuleList([])
# project in
@@ -381,7 +534,7 @@ class FeedForward(nn.Module):
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(linear_cls(inner_dim, dim_out))
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
+251 -3
View File
@@ -109,15 +109,17 @@ class Attention(nn.Module):
residual_connection: bool = False,
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
):
super().__init__()
self.inner_dim = dim_head * heads
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded
# with an deprecated state dict so that we can convert it on the fly
@@ -126,7 +128,7 @@ class Attention(nn.Module):
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = heads
self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
@@ -193,7 +195,7 @@ class Attention(nn.Module):
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
@@ -1975,6 +1977,250 @@ class LoRAAttnAddedKVProcessor(nn.Module):
return attn.processor(attn, hidden_states, *args, **kwargs)
class IPAdapterAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapater.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, defaults to 4):
The context length of the image features.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=1.0,
):
if scale != 1.0:
logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# split hidden states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAdapterAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, defaults to 4):
The context length of the image features.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=1.0,
):
if scale != 1.0:
logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# split hidden states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -1998,6 +2244,8 @@ CROSS_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
AttentionProcessor = Union[
+4 -1
View File
@@ -18,7 +18,7 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils.accelerate_utils import apply_forward_hook
from .autoencoder_kl import AutoencoderKLOutput
from .modeling_outputs import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
@@ -108,6 +108,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.use_slicing = False
self.use_tiling = False
self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True

Some files were not shown because too many files have changed in this diff Show More