Compare commits

..

142 Commits

Author SHA1 Message Date
Dhruv Nair 56e8fca572 Merge branch 'main' into test-v 2023-11-27 13:36:38 +00:00
Dhruv Nair c5941a26a4 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 13:35:36 +00:00
Patrick von Platen b135b6e905 [From_pretrained] Fix warning (#5948) 2023-11-27 14:35:19 +01:00
T. Xu 14a0d21d2e [Community Pipeline] Diffusion Posterior Sampling for General Noisy Inverse Problems (#5939)
* [community pipeline] dps impl

* add type checking

* pass ruff check

* ruff formatter
2023-11-27 14:29:42 +01:00
Dhruv Nair 8bc42512fe remove post quant conv 2023-11-27 13:27:46 +00:00
Patrick von Platen ebf581e85f [Tests] Make sure that we don't run tests multiple times (#5949)
* [Tests] Make sure that we don't run tests mulitple times

* [Tests] Make sure that we don't run tests mulitple times

* [Tests] Make sure that we don't run tests mulitple times
2023-11-27 14:18:56 +01:00
Patrick von Platen e550163b9f [Vae] Make sure all vae's work with latent diffusion models (#5880)
* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* add comments to explain the code better

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more
2023-11-27 14:17:47 +01:00
patil-suraj 55b4d09080 fix upcasting 2023-11-27 14:11:26 +01:00
patil-suraj c452d9c042 up 2023-11-27 13:59:30 +01:00
patil-suraj ee9f7d2493 make added_time_ids is tensor 2023-11-27 13:55:02 +01:00
Dhruv Nair 8620851aa0 update forward pass for gradient checkpointing 2023-11-27 12:50:58 +00:00
patil-suraj 90d8e832f8 upcast vae 2023-11-27 13:50:10 +01:00
Viktor Grygorchuk 20f0cbc88f fix: error on device for lpw_stable_diffusion_xl pipeline if pipe.enable_sequential_cpu_offload() enabled (#5885)
fix: set device for pipe.enable_sequential_cpu_offload()
2023-11-27 13:47:47 +01:00
patil-suraj 18930e0b85 doc 2023-11-27 13:40:30 +01:00
Chi d72a24b790 Replace multiple variables with one variable. (#5715)
* I added a new doc string to the class. This is more flexible to understanding other developers what are doing and where it's using.

* Update src/diffusers/models/unet_2d_blocks.py

This changes suggest by maintener.

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/unet_2d_blocks.py

Add suggested text

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update unet_2d_blocks.py

I changed the Parameter to Args text.

* Update unet_2d_blocks.py

proper indentation set in this file.

* Update unet_2d_blocks.py

a little bit of change in the act_fun argument line.

* I run the black command to reformat style in the code

* Update unet_2d_blocks.py

similar doc-string add to have in the original diffusion repository.

* I enhanced the code by replacing multiple redundant variables with a single variable, as they all served the same purpose. Additionally, I utilized the get_activation function for improved flexibility in choosing activation functions.

* Using as black package to reformated my file

* reverte some changes

* Remove conv_out_padding variables and using as conv_in_padding

* conv_out_padding create and add them into the code.

* run black command to solving styling problem

* add little bit space between comment and import statement

* I am utilizing the ruff library to address the style issues in my Makefile.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-27 13:34:52 +01:00
ginjia d3cda804e7 add LoRA weights load and fuse support for IPEX pipeline (#5920)
add IPEX pipeline LoRA weights loading support
2023-11-27 13:32:43 +01:00
patil-suraj 847bd0a479 fix copies 2023-11-27 13:23:31 +01:00
dg845 07eac4d65a Fix LCM Stable Diffusion distillation bug related to parsing unet_time_cond_proj_dim (#5893)
* Fix bug related to parsing unet_time_cond_proj_dim.

* Fix analogous bug in the SD-XL LCM distillation script.
2023-11-27 13:00:40 +01:00
Iván de Prado c079cae3d4 Avoid computing min() that is expensive when do_normalize is False in the image processor (#5896)
Avoid computing min() that is expensive when do_normalize is False

Avoid extra computing when do_normalize is False
2023-11-27 12:46:26 +01:00
Wang, Yi c7bfb8b22a set the model to train state before accelerator prepare (#5099)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
2023-11-27 12:43:49 +01:00
dg845 67d070749a Add Custom Timesteps Support to LCMScheduler and Supported Pipelines (#5874)
* Add custom timesteps support to LCMScheduler.

* Add custom timesteps support to StableDiffusionPipeline.

* Add custom timesteps support to StableDiffusionXLPipeline.

* Add custom timesteps support to remaining Stable Diffusion pipelines which support LCMScheduler (img2img, inpaint).

* Add custom timesteps support to remaining Stable Diffusion XL pipelines which support LCMScheduler (img2img, inpaint).

* Add custom timesteps support to StableDiffusionControlNetPipeline.

* Add custom timesteps support to T21 Stable Diffusion (XL) Adapters.

* Clean up Stable Diffusion inpaint tests.

* Manually add support for custom timesteps to AltDiffusion pipelines since make fix-copies doesn't appear to work correctly (it deletes the whole pipeline).

* make style

* Refactor pipeline timestep handling into the retrieve_timesteps function.
2023-11-27 12:39:14 +01:00
Dhruv Nair 3178b16b17 update 2023-11-27 11:37:52 +00:00
Aryan V S 9c357bda3f Deprecate KarrasVeScheduler and ScoreSdeVpScheduler (#5269)
* deprecated: KarrasVeScheduler, ScoreSdeVpScheduler

* delete tests relevant to deprecated schedulers

* chore: run make style

* fix: import error caused due to incorrect _import_structure after deprecation

* fix: ScoreSdeVpScheduler was not importable from diffusers

* remove import added by assumption

* Update src/diffusers/schedulers/__init__.py as suggested by @patrickvonplaten

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* make it a part deprecated

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Fix

* fix

* fix doc

* fix doc....again.......

* remove karras_ve test folder

Co-Authored-By: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
2023-11-27 12:33:02 +01:00
patil-suraj a08ef009d1 use math for log 2023-11-27 12:16:02 +01:00
patil-suraj 804bdebe51 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 12:01:11 +01:00
patil-suraj a193e49dff use c_noise values for timesteps 2023-11-27 12:01:08 +01:00
Dhruv Nair c9d1727613 clean up 2023-11-27 11:00:02 +00:00
Sayak Paul 3f7c3511dc [Core] add support for gradient checkpointing in transformer_2d (#5943)
add support for gradient checkpointing in transformer_2d
2023-11-27 16:21:12 +05:30
Dhruv Nair 82cf60828f Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 10:50:12 +00:00
Dhruv Nair 26ed460265 clean up 2023-11-27 10:49:58 +00:00
Dhruv Nair 403a81c30d clean up temp decoder 2023-11-27 10:21:22 +00:00
patil-suraj 1b3cf2db5e Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 11:13:20 +01:00
patil-suraj b8d84c4320 fix norm eps in TransformerSpatioTemporalModel 2023-11-27 11:13:18 +01:00
Dhruv Nair 3fbe123d84 make temb optional in Decoder mid block 2023-11-27 10:09:41 +00:00
Dhruv Nair f7cf8c338c clean up 2023-11-27 09:53:56 +00:00
Dhruv Nair ab8076f234 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-27 09:50:00 +00:00
Dhruv Nair 7b6a0d48c6 add slow svd test 2023-11-27 09:45:00 +00:00
patil-suraj 6adae54046 clean TransformerSpatioTemporalModel 2023-11-27 10:34:44 +01:00
patil-suraj af85fb1bc1 clean up unet 2023-11-27 10:03:40 +01:00
Dhruv Nair 760333d524 add unet tests 2023-11-27 08:12:02 +00:00
Junsong Chen 7d6f30e89b [Fix: pixart-alpha] random 512px resolution bug (#5842)
* [Fix: pixart-alpha]
add ASPECT_RATIO_512_BIN in use_resolution_binning for random 512px image generation.

* add slow test file for 512px generation without resolution binning

* fix: slow tests for resolution binning.

---------

Co-authored-by: jschen <chenjunsong4@h-partners.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2023-11-27 13:35:35 +05:30
Patrick von Platen 6d2e19f746 [Examples] Allow downloading variant model files (#5531)
* add variant

* add variant

* Apply suggestions from code review

* reformat

* fix: textual_inversion.py

* fix: variant in model_info

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2023-11-27 10:43:20 +05:30
patil-suraj f651c12ef8 don't scale image latents 2023-11-26 17:13:04 +01:00
patil-suraj d614a33a09 use AutoencoderKLTemporalDecoder 2023-11-26 17:00:22 +01:00
patil-suraj 13b646edd3 remove hack 2023-11-26 16:59:21 +01:00
patil-suraj cb49cbdd29 add pipeline and vae in init 2023-11-26 16:58:59 +01:00
patil-suraj 1ce8ff51e6 accept fps as arg 2023-11-26 16:20:22 +01:00
patil-suraj fdd182f335 allow passing PIL to export_video 2023-11-26 16:19:25 +01:00
patil-suraj 2a46326c25 up 2023-11-26 16:07:24 +01:00
patil-suraj e34e9d9a33 take guidance scale as input 2023-11-26 16:06:44 +01:00
patil-suraj 96af28f92b style 2023-11-26 16:01:32 +01:00
patil-suraj 6827a1dc6a add vae conversion 2023-11-26 15:42:27 +01:00
patil-suraj c3bdeb8a4c skip_post_quant_conv 2023-11-26 13:07:50 +01:00
patil-suraj cf70b9a0b4 fix missing activation in TemporalDecoder 2023-11-26 13:06:44 +01:00
patil-suraj 712b9950c5 fix guidance_scales dtype 2023-11-26 12:47:51 +01:00
patil-suraj 21148de853 fix typo 2023-11-26 12:45:01 +01:00
patil-suraj d930977656 fix attention in MidBlockTemporalDecoder 2023-11-26 12:01:14 +01:00
patil-suraj 268ffea0e7 cast alpha to sample dtype 2023-11-26 11:15:28 +01:00
patil-suraj 8bcf43d52a fix num frames during split decoding 2023-11-26 11:10:42 +01:00
patil-suraj b071aaa719 switch spatial to temporal for mixing in VAE 2023-11-26 10:51:53 +01:00
patil-suraj 5316fb5107 pass num frames in decode 2023-11-25 19:15:19 +01:00
patil-suraj 9af07d1d5c fix default values in vae 2023-11-25 19:09:47 +01:00
patil-suraj d0017d9b70 allow using differnt eps in temporal block for video decoder 2023-11-25 19:02:57 +01:00
patil-suraj 0cf6c6b291 type image_latents same as image_embeddings 2023-11-25 16:20:01 +01:00
patil-suraj df986274d6 fix dtype in TransformerSpatioTemporalModel 2023-11-25 16:17:45 +01:00
patil-suraj 7ddd14bd94 vae encode/decode in fp32 2023-11-25 16:16:01 +01:00
patil-suraj 4346ddd402 fix decode_latents 2023-11-25 14:33:25 +01:00
patil-suraj 9da55b381c pass decoding_t to decode_latents 2023-11-25 14:30:27 +01:00
patil-suraj 4d4469ee87 decode n frames at a time 2023-11-25 14:30:09 +01:00
patil-suraj f9954a0e7b decode in float32 2023-11-25 14:02:23 +01:00
patil-suraj e7798333c4 fix frame decodig 2023-11-25 14:01:01 +01:00
patil-suraj efb1e5e1d8 make pipeline run 2023-11-24 21:30:31 +01:00
Patrick von Platen 2a7f43a73b correct num inference steps 2023-11-24 17:09:26 +00:00
Patrick von Platen b978334d71 [@cene555][Kandinsky 3.0] Add Kandinsky 3.0 (#5913)
* finalize

* finalize

* finalize

* add slow test

* add slow test

* add slow test

* Fix more

* add slow test

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Better

* Fix more

* Fix more

* add slow test

* Add auto pipelines

* add slow test

* Add all

* add slow test

* add slow test

* add slow test

* add slow test

* add slow test

* Apply suggestions from code review

* add slow test

* add slow test
2023-11-24 17:46:00 +01:00
Dhruv Nair beaaf18b2c Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 16:36:06 +00:00
Dhruv Nair 132fe97bf4 add temporal autoencoder 2023-11-24 16:35:41 +00:00
Sayak Paul e5f232f76b [Docs] add: 8bit inference with pixart alpha (#5814)
* add: 8bit inference with pixart alpha

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add: note on 4bit.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* address comment

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-24 20:36:33 +05:30
patil-suraj 2f35e8c94c fix norm eps in temporal transformers 2023-11-24 15:40:41 +01:00
patil-suraj b336529573 add guidance scalings 2023-11-24 14:16:50 +01:00
patil-suraj 3e47d3c8ed adapt scheduler 2023-11-24 14:06:07 +01:00
patil-suraj 122a6bd390 begin pipeline 2023-11-24 13:36:57 +01:00
Dhruv Nair 37c428a79c Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 12:24:57 +00:00
Dhruv Nair eefed8ab6b update up/mid blocks for decoder 2023-11-24 12:23:14 +00:00
Dhruv Nair 05eaec2d39 Merge branch 'test-v-old' into test-v 2023-11-24 12:19:29 +00:00
Dhruv Nair e68424378f update vae 2023-11-24 12:19:11 +00:00
patil-suraj 24b5c4360c check for None 2023-11-24 11:53:50 +01:00
patil-suraj 0c4192b537 up 2023-11-24 11:51:40 +01:00
patil-suraj dff26ce8af up 2023-11-24 11:50:02 +01:00
patil-suraj 9f22651c1f remove more unsed args 2023-11-24 11:48:58 +01:00
patil-suraj d8c9e67aac remove unused arg 2023-11-24 11:38:34 +01:00
patil-suraj 6c28367b1a remove unused arg 2023-11-24 11:36:01 +01:00
patil-suraj f9def2aeed add in init 2023-11-24 11:31:30 +01:00
patil-suraj 576fa1c7dc remove UNetMidBlockSpatioTemporal 2023-11-24 11:30:35 +01:00
patil-suraj f1457b7e1d update conversion script 2023-11-24 11:24:42 +01:00
patil-suraj 1f34311eec rename model 2023-11-24 11:24:34 +01:00
patil-suraj f976f5a31e Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 11:17:55 +01:00
patil-suraj 8e1851a16a Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 11:17:51 +01:00
patil-suraj 6c69c7a0d2 add blocks 2023-11-24 11:11:15 +01:00
Dhruv Nair 6481e9495f make temb optional 2023-11-24 10:10:09 +00:00
Dhruv Nair 8c3fd58c85 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-24 09:51:43 +00:00
Dhruv Nair 9117547ee0 clean up 2023-11-24 09:51:29 +00:00
patil-suraj af1e86af8d fix time_context dim 2023-11-24 10:47:44 +01:00
patil-suraj 29551f8e30 fix TransformerSpatioTemporalModel 2023-11-24 10:19:44 +01:00
patil-suraj 661033171b use TransformerSpatioTemporalModel 2023-11-24 10:16:22 +01:00
patil-suraj 20efe541c5 fix TemporalBasicTransformerBlock 2023-11-24 10:11:40 +01:00
patil-suraj 5a523e21c6 reuse TemporalBasicTransformerBlock 2023-11-24 10:04:22 +01:00
patil-suraj b0fc4fd4cb fix SpatioTemporalResBlock 2023-11-24 10:01:09 +01:00
patil-suraj 678d19fa18 fix temb shape 2023-11-24 09:41:15 +01:00
patil-suraj c8ec445964 style 2023-11-24 09:34:53 +01:00
patil-suraj ffd9e26a65 use new blocks 2023-11-24 09:26:42 +01:00
patil-suraj 6f87490408 fix shapes in Alphablender and add time activation in res blcok 2023-11-24 08:57:28 +01:00
Dhruv Nair 9c9d46763b update 2023-11-24 07:12:50 +00:00
Dhruv Nair 47684dab43 update 2023-11-24 04:14:58 +00:00
Linoy Tsaban 3003ff4947 [bug fix] fix small bug in readme template of sdxl lora training script (#5914)
readme improvement and metadata fix
2023-11-23 19:08:49 +01:00
Dhruv Nair 5218f46173 fix blocks 2023-11-23 14:32:18 +00:00
Dhruv Nair 8ee280773f add vae blocks 2023-11-23 14:28:07 +00:00
Dhruv Nair 85846f7450 add spatio temporal transformers 2023-11-23 13:02:34 +00:00
patil-suraj 28dee6e735 fix temb shape in TemporalResnetBlock 2023-11-23 13:52:48 +01:00
patil-suraj 165ed7c5d5 return sample in original shape 2023-11-23 13:52:40 +01:00
patil-suraj d4cdfa33f5 make forward work 2023-11-23 13:35:52 +01:00
Linoy Tsaban 5ffa603244 [bug fix] fix small bug in readme template of sdxl lora training script (#5906)
* readme bug fix

* style fix

---------

Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
2023-11-23 12:11:50 +01:00
Dhruv Nair 1bd09b1489 Merge branch 'test-v' of https://github.com/huggingface/diffusers into test-v 2023-11-23 10:54:08 +00:00
Dhruv Nair edf7121ec7 add new resnet blocks 2023-11-23 10:53:25 +00:00
patil-suraj 7b64d3a17b up 2023-11-23 10:48:59 +01:00
patil-suraj c93606c93c fix model 2023-11-23 10:47:57 +01:00
patil-suraj 5df09ef355 add conversion script 2023-11-22 19:15:18 +01:00
patil-suraj ac9473153c fix add_embedding 2023-11-22 19:04:10 +01:00
patil-suraj ee9d7b8ecd fix time_pos_embed 2023-11-22 18:59:44 +01:00
patil-suraj 669824e5bb fix temporal res block 2023-11-22 17:44:56 +01:00
Linoy Tsaban 0eeee618cf Adds an advanced version of the SD-XL DreamBooth LoRA training script supporting pivotal tuning (#5883)
* sdxl dreambooth lora training script with pivotal tuning

* bug fix - args missing from parse_args

* code quality fixes

* comment unnecessary code from TokenEmbedding handler class

* fixup

---------

Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
2023-11-22 16:27:56 +01:00
patil-suraj 45c9b56bf7 use TimestepEmbedding 2023-11-22 15:56:09 +01:00
patil-suraj cad51d45d1 addition_time_embed_dim 2023-11-22 14:26:43 +01:00
patil-suraj 7de5d7c6fd add_embedding 2023-11-22 14:06:50 +01:00
patil-suraj 58883ee085 finish blocks 2023-11-22 13:42:10 +01:00
Andrés Romero 93f1a14cab ControlNet+Adapter pipeline, and ControlNet+Adapter+Inpaint pipeline (#5869)
* ControlNet+Adapter pipeline, and +Inpaint pipeline


---------

Co-authored-by: andres <andres@hax.ai>
2023-11-21 08:59:29 -10:00
Patrick von Platen 13d73d9303 [Lora] Seperate logic (#5809)
* [Lora] Seperate logic

* [Lora] Seperate logic

* [Lora] Seperate logic

* add comments to explain the code better

* add comments to explain the code better
2023-11-21 18:58:37 +01:00
YiYi Xu ba352aea29 [feat] IP Adapters (author @okotaku ) (#5713)
* add ip-adapter


---------

Co-authored-by: okotaku <to78314910@gmail.com>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2023-11-21 07:34:30 -10:00
Linoy Tsaban 6fac1369d0 Add features to the Dreambooth LoRA SDXL training script (#5508)
* Additions:
- support for different lr for text encoder
- support for Prodigy optimizer
- support for min snr gamma
- support for custom captions and dataset loading from the hub

* adjusted --caption_column behaviour (to -not- use the second column of the dataset by default if --caption_column is not provided)

* fixed --output_dir / --model_dir_name confusion

* added --repeats, --adam_weight_decay_text_encoder
+ some fixes

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* - import compute_snr from diffusers/training_utils.py
- cluster adamw together
- when using 'prodigy', if --train_text_encoder == True and --text_encoder_lr != --learning rate, changes the lr of the text encoders optimization params to be --learning_rate (otherwise errors)

* shape fixes when custom captions are used

* formatting and a little cleanup

* code styling

* --repeats default value fixed, changed to 1

* bug fix - removed redundant lines of embedding concatenation when using prior_preservation (that duplicated class_prompt embeddings)

* changed dataset loading logic according to the following usecases (to avoid unnecessary dependency on datasets)-
1. user provides --dataset_name
2. user provides local dir --instance_data_dir that contains a metadata .jsonl file
3. user provides local dir --instance_data_dir that contains only images
in cases [1,2] we import datasets and use load_dataset method, in case [3] we process the data same as in the original script setting

* styling fix

* arg name fix

* adjusted the --repeats logic

* -removed redundant arg and 'if' when loading local folder with prompts
-updated readme template
-some default val fixes
-custom caption tests

* image path fix for readme

* code style

* bug fix

* --caption_column arg

* readme fix

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
2023-11-21 17:38:43 +01:00
patil-suraj 2f5648177e begin model 2023-11-21 16:39:15 +01:00
Steven Liu 1093f9d615 [docs] MusicLDM (#5854)
* fix

* feedback
2023-11-21 15:27:41 +01:00
Aryan V S 81780882b8 Addition of new callbacks to controlnets (#5812)
* add new callbacks to src/diffusers/pipelines/controlnet/pipeline_controlnet.py

* update callbacks

* fix repeated kwarg

* update

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-11-21 15:22:20 +01:00
Dhruv Nair ebc7bedeb7 Add tests fetcher (#5848)
* add tests fetcher to utils

* add test fetcher

* update

* update

* remove unused dependency version check script

* update

* fix mistake

* update

* update

* update

* update

* update

* update

* update

* remove concurrency params

* update

* update

* update

* update

* update

* update

* move test fetcher to dedicated workflow
2023-11-21 18:01:44 +05:30
178 changed files with 20481 additions and 1042 deletions
+175
View File
@@ -0,0 +1,175 @@
name: Fast tests for PRs - Test Fetcher
on:
pull_request:
branches:
- main
push:
branches:
- ci-*
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
setup_pr_tests:
name: Setup PR Tests
runs-on: docker-cpu
container:
image: diffusers/diffusers-pytorch-cpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
outputs:
matrix: ${{ steps.set_matrix.outputs.matrix }}
test_map: ${{ steps.set_matrix.outputs.test_map }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .
- name: Environment
run: |
python utils/print_env.py
- name: Fetch Tests
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
- name: Report fetched tests
uses: actions/upload-artifact@v3
with:
name: test_fetched
path: test_preparation.txt
- id: set_matrix
name: Create Test Matrix
# The `keys` is used as GitHub actions matrix for jobs, i.e. `models`, `pipelines`, etc.
# The `test_map` is used to get the actual identified test files under each key.
# If no test to run (so no `test_map.json` file), create a dummy map (empty matrix will fail)
run: |
if [ -f test_map.json ]; then
keys=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); d = list(test_map.keys()); print(json.dumps(d))')
test_map=$(python3 -c 'import json; fp = open("test_map.json"); test_map = json.load(fp); fp.close(); print(json.dumps(test_map))')
else
keys=$(python3 -c 'keys = ["dummy"]; print(keys)')
test_map=$(python3 -c 'test_map = {"dummy": []}; print(test_map)')
fi
echo $keys
echo $test_map
echo "matrix=$keys" >> $GITHUB_OUTPUT
echo "test_map=$test_map" >> $GITHUB_OUTPUT
run_pr_tests:
name: Run PR Tests
needs: setup_pr_tests
if: contains(fromJson(needs.setup_pr_tests.outputs.matrix), 'dummy') != true
strategy:
fail-fast: false
max-parallel: 2
matrix:
modules: ${{ fromJson(needs.setup_pr_tests.outputs.matrix) }}
runs-on: docker-cpu
container:
image: diffusers/diffusers-pytorch-cpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
python -m pip install accelerate
- name: Environment
run: |
python utils/print_env.py
- name: Run all selected tests on CPU
run: |
python -m pytest -n 2 --dist=loadfile -v --make-reports=${{ matrix.modules }}_tests_cpu ${{ fromJson(needs.setup_pr_tests.outputs.test_map)[matrix.modules] }}
- name: Failure short reports
if: ${{ failure() }}
continue-on-error: true
run: |
cat reports/${{ matrix.modules }}_tests_cpu_stats.txt
cat reports/${{ matrix.modules }}_tests_cpu/failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.modules }}_test_reports
path: reports
run_staging_tests:
strategy:
fail-fast: false
matrix:
config:
- name: Hub tests for models, schedulers, and pipelines
framework: hub_tests_pytorch
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_hub
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
- name: Environment
run: |
python utils/print_env.py
- name: Run Hub tests for models, schedulers, and pipelines on a staging env
if: ${{ matrix.config.framework == 'hub_tests_pytorch' }}
run: |
HUGGINGFACE_CO_STAGING=true python -m pytest \
-m "is_staging_test" \
--make-reports=tests_${{ matrix.config.report }} \
tests
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
+4
View File
@@ -5,6 +5,10 @@ on:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
HF_HOME: /mnt/cache
+4
View File
@@ -13,6 +13,10 @@ env:
PYTEST_TIMEOUT: 600
RUN_SLOW: no
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
run_fast_tests_apple_m1:
name: Fast PyTorch MPS tests on MacOS
+2
View File
@@ -278,6 +278,8 @@
title: Kandinsky 2.1
- local: api/pipelines/kandinsky_v22
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
@@ -0,0 +1,24 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 3
TODO
## Kandinsky3Pipeline
[[autodoc]] Kandinsky3Pipeline
- all
- __call__
## Kandinsky3Img2ImgPipeline
[[autodoc]] Kandinsky3Img2ImgPipeline
- all
- __call__
+106
View File
@@ -35,6 +35,112 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>
## Inference with under 8GB GPU VRAM
Run the [`PixArtAlphaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.
First, install the [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) library:
```bash
pip install -U bitsandbytes
```
Then load the text encoder in 8-bit:
```python
from transformers import T5EncoderModel
from diffusers import PixArtAlphaPipeline
import torch
text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
subfolder="text_encoder",
load_in_8bit=True,
device_map="auto",
)
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
text_encoder=text_encoder,
transformer=None,
device_map="auto"
)
```
Now, use the `pipe` to encode a prompt:
```python
with torch.no_grad():
prompt = "cute cat"
prompt_embeds, prompt_attention_mask, negative_embeds, negative_prompt_attention_mask = pipe.encode_prompt(prompt)
```
Since text embeddings have been computed, remove the `text_encoder` and `pipe` from the memory, and free up som GPU VRAM:
```python
import gc
def flush():
gc.collect()
torch.cuda.empty_cache()
del text_encoder
del pipe
flush()
```
Then compute the latents with the prompt embeddings as inputs:
```python
pipe = PixArtAlphaPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
text_encoder=None,
torch_dtype=torch.float16,
).to("cuda")
latents = pipe(
negative_prompt=None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
num_images_per_prompt=1,
output_type="latent",
).images
del pipe.transformer
flush()
```
<Tip>
Notice that while initializing `pipe`, you're setting `text_encoder` to `None` so that it's not loaded.
</Tip>
Once the latents are computed, pass it off to the VAE to decode into a real image:
```python
with torch.no_grad():
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
image = pipe.image_processor.postprocess(image, output_type="pil")[0]
image.save("cat.png")
```
By deleting components you aren't using and flushing the GPU VRAM, you should be able to run [`PixArtAlphaPipeline`] with under 8GB GPU VRAM.
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/pixart/8bits_cat.png)
If you want a report of your memory-usage, run this [script](https://gist.github.com/sayakpaul/3ae0f847001d342af27018a96f467e4e).
<Tip warning={true}>
Text embeddings computed in 8-bit can impact the quality of the generated images because of the information loss in the representation space caused by the reduced precision. It's recommended to compare the outputs with and without 8-bit.
</Tip>
While loading the `text_encoder`, you set `load_in_8bit` to `True`. You could also specify `load_in_4bit` to bring your memory requirements down even further to under 7GB.
## PixArtAlphaPipeline
[[autodoc]] PixArtAlphaPipeline
@@ -25,4 +25,4 @@ The abstract from the paper is:
</Tip>
## ScoreSdeVpScheduler
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
[[autodoc]] schedulers.deprecated.scheduling_sde_vp.ScoreSdeVpScheduler
@@ -18,4 +18,4 @@ specific language governing permissions and limitations under the License.
[[autodoc]] KarrasVeScheduler
## KarrasVeOutput
[[autodoc]] schedulers.scheduling_karras_ve.KarrasVeOutput
[[autodoc]] schedulers.deprecated.scheduling_karras_ve.KarrasVeOutput
@@ -307,3 +307,331 @@ prompt = "a house by william eggleston, sunrays, beautiful, sunlight, sunrays, b
image = pipeline(prompt=prompt).images[0]
image
```
## IP-Adapter
[IP-Adapter](https://ip-adapter.github.io/) is an effective and lightweight adapter that adds image prompting capabilities to a diffusion model. This adapter works by decoupling the cross-attention layers of the image and text features. All the other model components are frozen and only the embedded image features in the UNet are trained. As a result, IP-Adapter files are typically only ~100MBs.
IP-Adapter works with most of our pipelines, including Stable Diffusion, Stable Diffusion XL (SDXL), ControlNet, T2I-Adapter, AnimateDiff. And you can use any custom models finetuned from the same base models. It also works with LCM-Lora out of box.
<Tip>
You can find official IP-Adapter checkpoints in [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter).
IP-Adapter was contributed by [okotaku](https://github.com/okotaku).
</Tip>
Let's first create a Stable Diffusion Pipeline.
```py
from diffusers import AutoPipelineForText2Image
import torch
from diffusers.utils import load_image
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
```
Now load the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) weights with the [`~loaders.IPAdapterMixin.load_ip_adapter`] method.
```py
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
```
<Tip>
IP-Adapter relies on an image encoder to generate the image features, if your IP-Adapter weights folder contains a "image_encoder" subfolder, the image encoder will be automatically loaded and registered to the pipeline. Otherwise you can so load a [`~transformers.CLIPVisionModelWithProjection`] model and pass it to a Stable Diffusion pipeline when you create it.
```py
from diffusers import AutoPipelineForText2Image, CLIPVisionModelWithProjection
import torch
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"h94/IP-Adapter",
subfolder="models/image_encoder",
torch_dtype=torch.float16,
).to("cuda")
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", image_encoder=image_encoder, torch_dtype=torch.float16).to("cuda")
```
</Tip>
IP-Adapter allows you to use both image and text to condition the image generation process. For example, let's use the bear image from the [Textual Inversion](#textual-inversion) section as the image prompt (`ip_adapter_image`) along with a text prompt to add "sunglasses". 😎
```py
pipeline.set_ip_adapter_scale(0.6)
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality, wearing sunglasses',
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator,
).images
images[0]
```
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ip-bear.png" />
</div>
<Tip>
You can use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method to adjust the text prompt and image prompt condition ratio.  If you're only using the image prompt, you should set the scale to `1.0`. You can lower the scale to get more generation diversity, but it'll be less aligned with the prompt.
`scale=0.5` can achieve good results in most cases when you use both text and image prompts.
</Tip>
IP-Adapter also works great with Image-to-Image and Inpainting pipelines. See below examples of how you can use it with Image-to-Image and Inpaint.
<hfoptions id="tasks">
<hfoption id="image-to-image">
```py
from diffusers import AutoPipelineForImage2Image
import torch
from diffusers.utils import load_image
pipeline = AutoPipelineForImage2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/vermeer.jpg")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/river.png")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
    prompt='best quality, high quality',
    image = image,
    ip_adapter_image=ip_image,
    num_inference_steps=50,
    generator=generator,
    strength=0.6,
).images
images[0]
```
</hfoption>
<hfoption id="inpaint">
```py
from diffusers import AutoPipelineForInpaint
import torch
from diffusers.utils import load_image
pipeline = AutoPipelineForInpaint.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float).to("cuda")
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/inpaint_image.png")
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/mask.png")
ip_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/girl.png")
image = image.resize((512, 768))
mask = mask.resize((512, 768))
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
prompt='best quality, high quality',
image = image,
mask_image = mask,
ip_adapter_image=ip_image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=50,
generator=generator,
strength=0.5,
).images
images[0]
```
</hfoption>
</hfoptions>
IP-Adapters can also be used with [SDXL](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
```python
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
image = load_image("https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
image = pipeline(
prompt="best quality, high quality",
ip_adapter_image=image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=25,
generator=generator,
).images[0]
image.save("sdxl_t2i.png")
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/watercolor_painting.jpeg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/sayakpaul/sample-datasets/resolve/main/sdxl_t2i.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
</div>
</div>
### LCM-Lora
You can use IP-Adapter with LCM-Lora to achieve "instant fine-tune" with custom images. Note that you need to load IP-Adapter weights before loading the LCM-Lora weights.
```py
from diffusers import DiffusionPipeline, LCMScheduler
import torch
from diffusers.utils import load_image
model_id = "sd-dreambooth-library/herge-style"
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe.load_lora_weights(lcm_lora_id)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
prompt = "best quality, high quality"
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
images = pipe(
prompt=prompt,
ip_adapter_image=image,
num_inference_steps=4,
guidance_scale=1,
).images[0]
```
### Other pipelines
IP-Adapter is compatible with any pipeline that (1) uses a text prompt and (2) uses Stable Diffusion or Stable Diffusion XL checkpoint. To use IP-Adapter with a different pipeline, all you need to do is to run `load_ip_adapter()` method after you create the pipeline, and then pass your image to the pipeline as `ip_adapter_image`
<Tip>
🤗 Diffusers currently only supports using IP-Adapter with some of the most popular pipelines, feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require integrating IP-adapters with a pipeline that does not support it yet!
</Tip>
You can find below examples on how to use IP-Adapter with ControlNet and AnimateDiff.
<hfoptions id="model">
<hfoption id="ControlNet">
```
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
import torch
from diffusers.utils import load_image
controlnet_model_path = "lllyasviel/control_v11f1p_sd15_depth"
controlnet = ControlNetModel.from_pretrained(controlnet_model_path, torch_dtype=torch.float16)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16)
pipeline.to("cuda")
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png")
depth_map = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/depth.png")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
generator = torch.Generator(device="cpu").manual_seed(33)
images = pipeline(
prompt='best quality, high quality',
image=depth_map,
ip_adapter_image=image,
negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
num_inference_steps=50,
generator=generator,
).images
images[0]
```
<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/statue.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">input image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">adapted image</figcaption>
</div>
</div>
</hfoption>
<hfoption id="AnimateDiff">
```py
# animate diff + ip adapter
import torch
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
from diffusers.utils import export_to_gif, load_image
# Load the motion adapter
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
# load SD 1.5 based finetuned model
model_id = "Lykon/DreamShaper"
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
# scheduler
scheduler = DDIMScheduler(
clip_sample=False,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
timestep_spacing="trailing",
steps_offset=1
)
pipe.scheduler = scheduler
# enable memory savings
pipe.enable_vae_slicing()
pipe.enable_model_cpu_offload()
# load ip_adapter
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
# load motion adapters
pipe.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe.load_lora_weights("guoyww/animatediff-motion-lora-tilt-up", adapter_name="tilt-up")
pipe.load_lora_weights("guoyww/animatediff-motion-lora-pan-left", adapter_name="pan-left")
seed = 42
image = load_image("https://user-images.githubusercontent.com/24734142/266492875-2d50d223-8475-44f0-a7c6-08b51cb53572.png")
images = [image] * 3
prompts = ["best quality, high quality"] * 3
negative_prompt = "bad quality, worst quality"
adapter_weights = [[0.75, 0.0, 0.0], [0.0, 0.0, 0.75], [0.0, 0.75, 0.75]]
# generate
output_frames = []
for prompt, image, adapter_weight in zip(prompts, images, adapter_weights):
pipe.set_adapters(["zoom-out", "tilt-up", "pan-left"], adapter_weights=adapter_weight)
output = pipe(
prompt= prompt,
num_frames=16,
guidance_scale=7.5,
num_inference_steps=30,
ip_adapter_image = image,
generator=torch.Generator("cpu").manual_seed(seed),
)
frames = output.frames[0]
output_frames.extend(frames)
export_to_gif(output_frames, "test_out_animation.gif")
```
</hfoption>
</hfoptions>
File diff suppressed because it is too large Load Diff
+280
View File
@@ -2343,3 +2343,283 @@ images = pipe(
assert len(images) == (len(prompts) - 1) * num_interpolation_steps
```
### ControlNet + T2I Adapter Pipeline
This pipelines combines both ControlNet and T2IAdapter into a single pipeline, where the forward pass is executed once.
It receives `control_image` and `adapter_image`, as well as `controlnet_conditioning_scale` and `adapter_conditioning_scale`, for the ControlNet and Adapter modules, respectively. Whenever `adapter_conditioning_scale = 0` or `controlnet_conditioning_scale = 0`, it will act as a full ControlNet module or as a full T2IAdapter module, respectively.
```py
import cv2
import numpy as np
import torch
from controlnet_aux.midas import MidasDetector
from PIL import Image
from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.utils import load_image
from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter import (
StableDiffusionXLControlNetAdapterPipeline,
)
controlnet_depth = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
adapter_depth = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetAdapterPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
controlnet=controlnet_depth,
adapter=adapter_depth,
vae=vae,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to("cuda")
prompt = "a tiger sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
image = load_image(img_url).resize((1024, 1024))
depth_image = midas_depth(
image, detect_resolution=512, image_resolution=1024
)
strength = 0.5
images = pipe(
prompt,
control_image=depth_image,
adapter_image=depth_image,
num_inference_steps=30,
controlnet_conditioning_scale=strength,
adapter_conditioning_scale=strength,
).images
images[0].save("controlnet_and_adapter.png")
```
### ControlNet + T2I Adapter + Inpainting Pipeline
```py
import cv2
import numpy as np
import torch
from controlnet_aux.midas import MidasDetector
from PIL import Image
from diffusers import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.utils import load_image
from examples.community.pipeline_stable_diffusion_xl_controlnet_adapter_inpaint import (
StableDiffusionXLControlNetAdapterInpaintPipeline,
)
controlnet_depth = ControlNetModel.from_pretrained(
"diffusers/controlnet-depth-sdxl-1.0",
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True
)
adapter_depth = T2IAdapter.from_pretrained(
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True)
pipe = StableDiffusionXLControlNetAdapterInpaintPipeline.from_pretrained(
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
controlnet=controlnet_depth,
adapter=adapter_depth,
vae=vae,
variant="fp16",
use_safetensors=True,
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention()
# pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
midas_depth = MidasDetector.from_pretrained(
"valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
).to("cuda")
prompt = "a tiger sitting on a park bench"
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
image = load_image(img_url).resize((1024, 1024))
mask_image = load_image(mask_url).resize((1024, 1024))
depth_image = midas_depth(
image, detect_resolution=512, image_resolution=1024
)
strength = 0.4
images = pipe(
prompt,
image=image,
mask_image=mask_image,
control_image=depth_image,
adapter_image=depth_image,
num_inference_steps=30,
controlnet_conditioning_scale=strength,
adapter_conditioning_scale=strength,
strength=0.7,
).images
images[0].save("controlnet_and_adapter_inpaint.png")
```
## Diffusion Posterior Sampling Pipeline
* Reference paper
```
@article{chung2022diffusion,
title={Diffusion posterior sampling for general noisy inverse problems},
author={Chung, Hyungjin and Kim, Jeongsol and Mccann, Michael T and Klasky, Marc L and Ye, Jong Chul},
journal={arXiv preprint arXiv:2209.14687},
year={2022}
}
```
* This pipeline allows zero-shot conditional sampling from the posterior distribution $p(x|y)$, given observation on $y$, unconditional generative model $p(x)$ and differentiable operator $y=f(x)$.
* For example, $f(.)$ can be downsample operator, then $y$ is a downsampled image, and the pipeline becomes a super-resolution pipeline.
* To use this pipeline, you need to know your operator $f(.)$ and corrupted image $y$, and pass them during the call. For example, as in the main function of dps_pipeline.py, you need to first define the Gaussian blurring operator $f(.)$. The operator should be a callable nn.Module, with all the parameter gradient disabled:
```python
import torch.nn.functional as F
import scipy
from torch import nn
# define the Gaussian blurring operator first
class GaussialBlurOperator(nn.Module):
def __init__(self, kernel_size, intensity):
super().__init__()
class Blurkernel(nn.Module):
def __init__(self, blur_type='gaussian', kernel_size=31, std=3.0):
super().__init__()
self.blur_type = blur_type
self.kernel_size = kernel_size
self.std = std
self.seq = nn.Sequential(
nn.ReflectionPad2d(self.kernel_size//2),
nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
)
self.weights_init()
def forward(self, x):
return self.seq(x)
def weights_init(self):
if self.blur_type == "gaussian":
n = np.zeros((self.kernel_size, self.kernel_size))
n[self.kernel_size // 2,self.kernel_size // 2] = 1
k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
k = torch.from_numpy(k)
self.k = k
for name, f in self.named_parameters():
f.data.copy_(k)
elif self.blur_type == "motion":
k = Kernel(size=(self.kernel_size, self.kernel_size), intensity=self.std).kernelMatrix
k = torch.from_numpy(k)
self.k = k
for name, f in self.named_parameters():
f.data.copy_(k)
def update_weights(self, k):
if not torch.is_tensor(k):
k = torch.from_numpy(k)
for name, f in self.named_parameters():
f.data.copy_(k)
def get_kernel(self):
return self.k
self.kernel_size = kernel_size
self.conv = Blurkernel(blur_type='gaussian',
kernel_size=kernel_size,
std=intensity)
self.kernel = self.conv.get_kernel()
self.conv.update_weights(self.kernel.type(torch.float32))
for param in self.parameters():
param.requires_grad=False
def forward(self, data, **kwargs):
return self.conv(data)
def transpose(self, data, **kwargs):
return data
def get_kernel(self):
return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
```
* Next, you should obtain the corrupted image $y$ by the operator. In this example, we generate $y$ from the source image $x$. However in practice, having the operator $f(.)$ and corrupted image $y$ is enough:
```python
# set up source image
src = Image.open('sample.png')
# read image into [1,3,H,W]
src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2,0,1)[None]
# normalize image to [-1,1]
src = (src / 127.5) - 1.0
src = src.to("cuda")
# set up operator and measurement
operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
measurement = operator(src)
# save the source and corrupted images
save_image((src+1.0)/2.0, "dps_src.png")
save_image((measurement+1.0)/2.0, "dps_mea.png")
```
* We provide an example pair of saved source and corrupted images, using the Gaussian blur operator above
* Source image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/4d2a1216-08d1-4aeb-9ce3-7a2d87561d65)
* Gaussian blurred image:
* ![ddpm_generated_image](https://github.com/tongdaxu/Images/assets/22267548/65076258-344b-4ed8-b704-a04edaade8ae)
* You can download those image to run the example on your own.
* Next, we need to define a loss function used for diffusion posterior sample. For most of the cases, the RMSE is fine:
```python
def RMSELoss(yhat, y):
return torch.sqrt(torch.sum((yhat-y)**2))
```
* And next, as any other diffusion models, we need the score estimator and scheduler. As we are working with $256x256$ face images, we use ddmp-celebahq-256:
```python
# set up scheduler
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(1000)
# set up model
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
```
* And finally, run the pipeline:
```python
# finally, the pipeline
dpspipe = DPSPipeline(model, scheduler)
image = dpspipe(
measurement = measurement,
operator = operator,
loss_fn = RMSELoss,
zeta = 1.0,
).images[0]
image.save("dps_generated_image.png")
```
* The zeta is a hyperparameter that is in range of $[0,1]$. It need to be tuned for best effect. By setting zeta=1, you should be able to have the reconstructed result:
* Reconstructed image:
* ![sample](https://github.com/tongdaxu/Images/assets/22267548/0ceb5575-d42e-4f0b-99c0-50e69c982209)
* The reconstruction is perceptually similar to the source image, but different in details.
* In dps_pipeline.py, we also provide a super-resolution example, which should produce:
* Downsampled image:
* ![dps_mea](https://github.com/tongdaxu/Images/assets/22267548/ff6a33d6-26f0-42aa-88ce-f8a76ba45a13)
* Reconstructed image:
* ![dps_generated_image](https://github.com/tongdaxu/Images/assets/22267548/b74f084d-93f4-4845-83d8-44c0fa758a5f)
+466
View File
@@ -0,0 +1,466 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from math import pi
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from PIL import Image
from diffusers import DDPMScheduler, DiffusionPipeline, ImagePipelineOutput, UNet2DModel
from diffusers.utils.torch_utils import randn_tensor
class DPSPipeline(DiffusionPipeline):
r"""
Pipeline for Diffusion Posterior Sampling.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Parameters:
unet ([`UNet2DModel`]):
A `UNet2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
[`DDPMScheduler`], or [`DDIMScheduler`].
"""
model_cpu_offload_seq = "unet"
def __init__(self, unet, scheduler):
super().__init__()
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(
self,
measurement: torch.Tensor,
operator: torch.nn.Module,
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
batch_size: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
num_inference_steps: int = 1000,
output_type: Optional[str] = "pil",
return_dict: bool = True,
zeta: float = 0.3,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
The call function to the pipeline for generation.
Args:
measurement (`torch.Tensor`, *required*):
A 'torch.Tensor', the corrupted image
operator (`torch.nn.Module`, *required*):
A 'torch.nn.Module', the operator generating the corrupted image
loss_fn (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *required*):
A 'Callable[[torch.Tensor, torch.Tensor], torch.Tensor]', the loss function used
between the measurements, for most of the cases using RMSE is fine.
batch_size (`int`, *optional*, defaults to 1):
The number of images to generate.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
num_inference_steps (`int`, *optional*, defaults to 1000):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Example:
```py
>>> from diffusers import DDPMPipeline
>>> # load model and scheduler
>>> pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
>>> # run pipeline in inference (sample random noise and denoise)
>>> image = pipe().images[0]
>>> # save image
>>> image.save("ddpm_generated_image.png")
```
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
# Sample gaussian noise to begin loop
if isinstance(self.unet.config.sample_size, int):
image_shape = (
batch_size,
self.unet.config.in_channels,
self.unet.config.sample_size,
self.unet.config.sample_size,
)
else:
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
if self.device.type == "mps":
# randn does not work reproducibly on mps
image = randn_tensor(image_shape, generator=generator)
image = image.to(self.device)
else:
image = randn_tensor(image_shape, generator=generator, device=self.device)
# set step values
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
with torch.enable_grad():
# 1. predict noise model_output
image = image.requires_grad_()
model_output = self.unet(image, t).sample
# 2. compute previous image x'_{t-1} and original prediction x0_{t}
scheduler_out = self.scheduler.step(model_output, t, image, generator=generator)
image_pred, origi_pred = scheduler_out.prev_sample, scheduler_out.pred_original_sample
# 3. compute y'_t = f(x0_{t})
measurement_pred = operator(origi_pred)
# 4. compute loss = d(y, y'_t-1)
loss = loss_fn(measurement, measurement_pred)
loss.backward()
print("distance: {0:.4f}".format(loss.item()))
with torch.no_grad():
image_pred = image_pred - zeta * image.grad
image = image_pred.detach()
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
if __name__ == "__main__":
import scipy
from torch import nn
from torchvision.utils import save_image
# defining the operators f(.) of y = f(x)
# super-resolution operator
class SuperResolutionOperator(nn.Module):
def __init__(self, in_shape, scale_factor):
super().__init__()
# Resizer local class, do not use outiside the SR operator class
class Resizer(nn.Module):
def __init__(self, in_shape, scale_factor=None, output_shape=None, kernel=None, antialiasing=True):
super(Resizer, self).__init__()
# First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
scale_factor, output_shape = self.fix_scale_and_size(in_shape, output_shape, scale_factor)
# Choose interpolation method, each method has the matching kernel size
def cubic(x):
absx = np.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) + (
-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
) * ((1 < absx) & (absx <= 2))
def lanczos2(x):
return (
(np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps)
/ ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)
) * (abs(x) < 2)
def box(x):
return ((-0.5 <= x) & (x < 0.5)) * 1.0
def lanczos3(x):
return (
(np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps)
/ ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)
) * (abs(x) < 3)
def linear(x):
return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
method, kernel_width = {
"cubic": (cubic, 4.0),
"lanczos2": (lanczos2, 4.0),
"lanczos3": (lanczos3, 6.0),
"box": (box, 1.0),
"linear": (linear, 2.0),
None: (cubic, 4.0), # set default interpolation method as cubic
}.get(kernel)
# Antialiasing is only used when downscaling
antialiasing *= np.any(np.array(scale_factor) < 1)
# Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient
sorted_dims = np.argsort(np.array(scale_factor))
self.sorted_dims = [int(dim) for dim in sorted_dims if scale_factor[dim] != 1]
# Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
field_of_view_list = []
weights_list = []
for dim in self.sorted_dims:
# for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
# weights that multiply the values there to get its result.
weights, field_of_view = self.contributions(
in_shape[dim], output_shape[dim], scale_factor[dim], method, kernel_width, antialiasing
)
# convert to torch tensor
weights = torch.tensor(weights.T, dtype=torch.float32)
# We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
# tmp_im[field_of_view.T], (bsxfun style)
weights_list.append(
nn.Parameter(
torch.reshape(weights, list(weights.shape) + (len(scale_factor) - 1) * [1]),
requires_grad=False,
)
)
field_of_view_list.append(
nn.Parameter(
torch.tensor(field_of_view.T.astype(np.int32), dtype=torch.long), requires_grad=False
)
)
self.field_of_view = nn.ParameterList(field_of_view_list)
self.weights = nn.ParameterList(weights_list)
def forward(self, in_tensor):
x = in_tensor
# Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim
for dim, fov, w in zip(self.sorted_dims, self.field_of_view, self.weights):
# To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
x = torch.transpose(x, dim, 0)
# This is a bit of a complicated multiplication: x[field_of_view.T] is a tensor of order image_dims+1.
# for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim
# only, this is why it only adds 1 dim to 5the shape). We then multiply, for each pixel, its set of positions with
# the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style:
# matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the
# same number
x = torch.sum(x[fov] * w, dim=0)
# Finally we swap back the axes to the original order
x = torch.transpose(x, dim, 0)
return x
def fix_scale_and_size(self, input_shape, output_shape, scale_factor):
# First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the
# same size as the number of input dimensions)
if scale_factor is not None:
# By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
if np.isscalar(scale_factor) and len(input_shape) > 1:
scale_factor = [scale_factor, scale_factor]
# We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales
scale_factor = list(scale_factor)
scale_factor = [1] * (len(input_shape) - len(scale_factor)) + scale_factor
# Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size
# to all the unspecified dimensions
if output_shape is not None:
output_shape = list(input_shape[len(output_shape) :]) + list(np.uint(np.array(output_shape)))
# Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is
# sub-optimal, because there can be different scales to the same output-shape.
if scale_factor is None:
scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
# Dealing with missing output-shape. calculating according to scale-factor
if output_shape is None:
output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
return scale_factor, output_shape
def contributions(self, in_length, out_length, scale, kernel, kernel_width, antialiasing):
# This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
# such that each position from the field_of_view will be multiplied with a matching filter from the
# 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
# around it. This is only done for one dimension of the image.
# When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of
# 1/sf. this means filtering is more 'low-pass filter'.
fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
kernel_width *= 1.0 / scale if antialiasing else 1.0
# These are the coordinates of the output image
out_coordinates = np.arange(1, out_length + 1)
# since both scale-factor and output size can be provided simulatneously, perserving the center of the image requires shifting
# the output coordinates. the deviation is because out_length doesn't necesary equal in_length*scale.
# to keep the center we need to subtract half of this deivation so that we get equal margins for boths sides and center is preserved.
shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
# These are the matching positions of the output-coordinates on the input image coordinates.
# Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
# [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
# The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to
# the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big
# one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).
# So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is
# at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means:
# (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
match_coordinates = shifted_out_coordinates / scale + 0.5 * (1 - 1 / scale)
# This is the left boundary to start multiplying the filter from, it depends on the size of the filter
left_boundary = np.floor(match_coordinates - kernel_width / 2)
# Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers
# of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them)
expanded_kernel_width = np.ceil(kernel_width) + 2
# Determine a set of field_of_view for each each output position, these are the pixels in the input image
# that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the
# vertical dim is the pixels it 'sees' (kernel_size + 2)
field_of_view = np.squeeze(
np.int16(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)
)
# Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the
# vertical dim is a list of weights matching to the pixel in the field of view (that are specified in
# 'field_of_view')
weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
# Normalize weights to sum up to 1. be careful from dividing by 0
sum_weights = np.sum(weights, axis=1)
sum_weights[sum_weights == 0] = 1.0
weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
# We use this mirror structure as a trick for reflection padding at the boundaries
mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
# Get rid of weights and pixel positions that are of zero weight
non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
weights = np.squeeze(weights[:, non_zero_out_pixels])
field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
# Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size
return weights, field_of_view
self.down_sample = Resizer(in_shape, 1 / scale_factor)
for param in self.parameters():
param.requires_grad = False
def forward(self, data, **kwargs):
return self.down_sample(data)
# Gaussian blurring operator
class GaussialBlurOperator(nn.Module):
def __init__(self, kernel_size, intensity):
super().__init__()
class Blurkernel(nn.Module):
def __init__(self, blur_type="gaussian", kernel_size=31, std=3.0):
super().__init__()
self.blur_type = blur_type
self.kernel_size = kernel_size
self.std = std
self.seq = nn.Sequential(
nn.ReflectionPad2d(self.kernel_size // 2),
nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3),
)
self.weights_init()
def forward(self, x):
return self.seq(x)
def weights_init(self):
if self.blur_type == "gaussian":
n = np.zeros((self.kernel_size, self.kernel_size))
n[self.kernel_size // 2, self.kernel_size // 2] = 1
k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
k = torch.from_numpy(k)
self.k = k
for name, f in self.named_parameters():
f.data.copy_(k)
def update_weights(self, k):
if not torch.is_tensor(k):
k = torch.from_numpy(k)
for name, f in self.named_parameters():
f.data.copy_(k)
def get_kernel(self):
return self.k
self.kernel_size = kernel_size
self.conv = Blurkernel(blur_type="gaussian", kernel_size=kernel_size, std=intensity)
self.kernel = self.conv.get_kernel()
self.conv.update_weights(self.kernel.type(torch.float32))
for param in self.parameters():
param.requires_grad = False
def forward(self, data, **kwargs):
return self.conv(data)
def transpose(self, data, **kwargs):
return data
def get_kernel(self):
return self.kernel.view(1, 1, self.kernel_size, self.kernel_size)
# assuming the forward process y = f(x) is polluted by Gaussian noise, use l2 norm
def RMSELoss(yhat, y):
return torch.sqrt(torch.sum((yhat - y) ** 2))
# set up source image
src = Image.open("sample.png")
# read image into [1,3,H,W]
src = torch.from_numpy(np.array(src, dtype=np.float32)).permute(2, 0, 1)[None]
# normalize image to [-1,1]
src = (src / 127.5) - 1.0
src = src.to("cuda")
# set up operator and measurement
# operator = SuperResolutionOperator(in_shape=src.shape, scale_factor=4).to("cuda")
operator = GaussialBlurOperator(kernel_size=61, intensity=3.0).to("cuda")
measurement = operator(src)
# set up scheduler
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(1000)
# set up model
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to("cuda")
save_image((src + 1.0) / 2.0, "dps_src.png")
save_image((measurement + 1.0) / 2.0, "dps_mea.png")
# finally, the pipeline
dpspipe = DPSPipeline(model, scheduler)
image = dpspipe(
measurement=measurement,
operator=operator,
loss_fn=RMSELoss,
zeta=1.0,
).images[0]
image.save("dps_generated_image.png")
+14 -10
View File
@@ -250,6 +250,7 @@ def get_weighted_text_embeddings_sdxl(
neg_prompt: str = "",
neg_prompt_2: str = None,
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
"""
This function can process long prompt with weights, no length limitation
@@ -262,10 +263,13 @@ def get_weighted_text_embeddings_sdxl(
neg_prompt (str)
neg_prompt_2 (str)
num_images_per_prompt (int)
device (torch.device)
Returns:
prompt_embeds (torch.Tensor)
neg_prompt_embeds (torch.Tensor)
"""
device = device or pipe._execution_device
if prompt_2:
prompt = f"{prompt} {prompt_2}"
@@ -330,17 +334,17 @@ def get_weighted_text_embeddings_sdxl(
# get prompt embeddings one by one is not working.
for i in range(len(prompt_token_groups)):
# get positive prompt embeddings with weights
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
# use first text encoder
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
# use second text encoder
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
pooled_prompt_embeds = prompt_embeds_2[0]
@@ -357,16 +361,16 @@ def get_weighted_text_embeddings_sdxl(
embeds.append(token_embedding)
# get negative prompt embeddings with weights
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
# use first text encoder
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
# use second text encoder
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+2 -2
View File
@@ -21,7 +21,7 @@ from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from diffusers.configuration_utils import FrozenDict
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
@@ -62,7 +62,7 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
class StableDiffusionIPEXPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion on IPEX.
@@ -657,6 +657,15 @@ def parse_args():
default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.",
)
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--ema_decay",
@@ -1138,7 +1147,7 @@ def main(args):
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
@@ -677,6 +677,15 @@ def parse_args():
default=0.001,
help="The huber loss parameter. Only used if `--loss_type=huber`.",
)
parser.add_argument(
"--unet_time_cond_proj_dim",
type=int,
default=256,
help=(
"The dimension of the guidance scale embedding in the U-Net, which will be used if the teacher U-Net"
" does not have `time_cond_proj_dim` set."
),
)
# ----Exponential Moving Average (EMA)----
parser.add_argument(
"--ema_decay",
@@ -1233,6 +1242,7 @@ def main(args):
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
@@ -1243,7 +1253,7 @@ def main(args):
noise_pred = unet(
noisy_model_input,
start_timesteps,
timestep_cond=None,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
@@ -1308,7 +1318,7 @@ def main(args):
target_noise_pred = target_unet(
x_prev.float(),
timesteps,
timestep_cond=None,
timestep_cond=w_embedding,
encoder_hidden_states=prompt_embeds.float(),
added_cond_kwargs=encoded_text,
).sample
+13 -7
View File
@@ -86,6 +86,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
controlnet=controlnet,
safety_checker=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -249,10 +250,13 @@ def parse_args(input_args=None):
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
@@ -767,11 +771,13 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
if args.controlnet_model_name_or_path:
+20 -9
View File
@@ -74,6 +74,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
unet=unet,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -243,15 +244,18 @@ def parse_args(input_args=None):
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
" If not specified controlnet weights are initialized from unet.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
@@ -793,10 +797,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -810,10 +820,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -824,9 +834,10 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
if args.controlnet_model_name_or_path:
@@ -332,6 +332,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -740,6 +746,7 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -801,11 +808,13 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# Adding a modifier token which is optimized ####
@@ -1229,6 +1238,7 @@ def main(args):
text_encoder=accelerator.unwrap_model(text_encoder),
tokenizer=tokenizer,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -1278,7 +1288,7 @@ def main(args):
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
+13 -7
View File
@@ -139,6 +139,7 @@ def log_validation(
text_encoder=text_encoder,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
**pipeline_args,
)
@@ -239,10 +240,13 @@ def parse_args(input_args=None):
type=str,
default=None,
required=False,
help=(
"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
" float32 precision."
),
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
@@ -859,6 +863,7 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -912,18 +917,18 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
if model_has_vae(args):
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
else:
vae = None
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -1379,6 +1384,7 @@ def main(args):
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
**pipeline_args,
)
+8 -2
View File
@@ -460,7 +460,10 @@ def main():
# Load models and create wrapper for stable diffusion
text_encoder = FlaxCLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype, revision=args.revision
args.pretrained_model_name_or_path,
subfolder="text_encoder",
dtype=weight_dtype,
revision=args.revision,
)
vae, vae_params = FlaxAutoencoderKL.from_pretrained(
vae_arg,
@@ -468,7 +471,10 @@ def main():
**vae_kwargs,
)
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype, revision=args.revision
args.pretrained_model_name_or_path,
subfolder="unet",
dtype=weight_dtype,
revision=args.revision,
)
# Optimization
+46 -5
View File
@@ -57,7 +57,7 @@ from diffusers.models.attention_processor import (
AttnAddedKVProcessor2_0,
SlicedAttnAddedKVProcessor,
)
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
@@ -70,6 +70,39 @@ check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -150,6 +183,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -717,6 +756,7 @@ def main(args):
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -770,11 +810,11 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
try:
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
except OSError:
# IF does not have a VAE so let's just set it to None
@@ -782,7 +822,7 @@ def main(args):
vae = None
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# We only train the additional adapter LoRA layers
@@ -1277,6 +1317,7 @@ def main(args):
unet=accelerator.unwrap_model(unet),
text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1362,7 +1403,7 @@ def main(args):
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+498 -110
View File
@@ -50,9 +50,9 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import unet_lora_state_dict
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -63,37 +63,100 @@ check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
instance_prompt=str,
validation_prompt=str,
repo_folder=None,
vae_path=None,
):
img_str = ""
img_str = "widget:\n" if images else ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
"""
yaml = f"""
---
license: openrail++
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- text-to-image
- diffusers
- lora
inference: true
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
model_card = f"""
# LoRA DreamBooth - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
model_card = f"""
# SDXL LoRA DreamBooth - {repo_id}
<Gallery />
## Model description
These are {repo_id} LoRA adaption weights for {base_model}.
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
LoRA for the text encoder was enabled: {train_text_encoder}.
Special VAE used for training: {vae_path}.
## Trigger words
You should use {instance_prompt} to trigger the image generation.
## Download model
Weights for this model are available in Safetensors format.
[Download]({repo_id}/tree/main) them in the Files & versions tab.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
@@ -141,13 +204,59 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
help=("A folder containing the training data. "),
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--image_column",
type=str,
default="image",
help="The column of the dataset containing the target image. By "
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
parser.add_argument(
"--caption_column",
type=str,
default=None,
help="The column of the dataset containing the instance prompt for each image",
)
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
parser.add_argument(
"--class_data_dir",
type=str,
@@ -160,7 +269,7 @@ def parse_args(input_args=None):
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
"--class_prompt",
@@ -299,9 +408,16 @@ def parse_args(input_args=None):
parser.add_argument(
"--learning_rate",
type=float,
default=5e-4,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--text_encoder_lr",
type=float,
default=5e-6,
help="Text encoder learning rate to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
@@ -317,6 +433,14 @@ def parse_args(input_args=None):
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
@@ -335,13 +459,59 @@ def parse_args(input_args=None):
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
"--optimizer",
type=str,
default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
)
parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--prodigy_beta3",
type=float,
default=None,
help="coefficients for computing the Prodidy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW",
)
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
)
parser.add_argument(
"--prodigy_use_bias_correction",
type=bool,
default=True,
help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
)
parser.add_argument(
"--prodigy_safeguard_warmup",
type=bool,
default=True,
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
"Ignored if optimizer is adamW",
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
@@ -414,6 +584,12 @@ def parse_args(input_args=None):
else:
args = parser.parse_args()
if args.dataset_name is None and args.instance_data_dir is None:
raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
if args.dataset_name is not None and args.instance_data_dir is not None:
raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -442,20 +618,84 @@ class DreamBoothDataset(Dataset):
def __init__(
self,
instance_data_root,
instance_prompt,
class_prompt,
class_data_root=None,
class_num=None,
size=1024,
repeats=1,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
self.class_prompt = class_prompt
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
# we load the training data using load_dataset
if args.dataset_name is not None:
try:
from datasets import load_dataset
except ImportError:
raise ImportError(
"You are trying to load your data using the datasets library. If you wish to train using custom "
"captions please install the datasets library: `pip install datasets`. If you wish to load a "
"local folder containing images only, specify --instance_data_dir instead."
)
# Downloading and loading a dataset from the hub.
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
)
# Preprocessing the datasets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
if args.caption_column is None:
logger.info(
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
"contains captions/prompts for the images, make sure to specify the "
"column as --caption_column"
)
self.custom_instance_prompts = None
else:
if args.caption_column not in column_names:
raise ValueError(
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
custom_instance_prompts = dataset["train"][args.caption_column]
# create final list of captions according to --repeats
self.custom_instance_prompts = []
for caption in custom_instance_prompts:
self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
else:
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
self.custom_instance_prompts = None
self.instance_images = []
for img in instance_images:
self.instance_images.extend(itertools.repeat(img, repeats))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
if class_data_root is not None:
@@ -484,13 +724,23 @@ class DreamBoothDataset(Dataset):
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
instance_image = self.instance_images[index % self.num_instance_images]
instance_image = exif_transpose(instance_image)
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
if caption:
example["instance_prompt"] = caption
else:
example["instance_prompt"] = self.instance_prompt
else: # costum prompts were provided, but length does not match size of image dataset
example["instance_prompt"] = self.instance_prompt
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)
@@ -498,22 +748,25 @@ class DreamBoothDataset(Dataset):
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt"] = self.class_prompt
return example
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
prompts = [example["instance_prompt"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
pixel_values += [example["class_images"] for example in examples]
prompts += [example["class_prompt"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values}
batch = {"pixel_values": pixel_values, "prompts": prompts}
return batch
@@ -630,6 +883,7 @@ def main(args):
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
revision=args.revision,
variant=args.variant,
)
pipeline.set_progress_bar_config(disable=True)
@@ -668,10 +922,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -685,10 +945,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -696,10 +956,13 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# We only train the additional adapter LoRA layers
@@ -732,7 +995,8 @@ def main(args):
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
"please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
else:
@@ -866,35 +1130,119 @@ def main(args):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
# Optimization parameters
unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
if args.train_text_encoder:
# different learning rate for text encoder and unet
text_lora_parameters_one_with_lr = {
"params": text_lora_parameters_one,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_lora_parameters_two_with_lr = {
"params": text_lora_parameters_two,
"weight_decay": args.adam_weight_decay_text_encoder,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
unet_lora_parameters_with_lr,
text_lora_parameters_one_with_lr,
text_lora_parameters_two_with_lr,
]
else:
optimizer_class = torch.optim.AdamW
params_to_optimize = [unet_lora_parameters_with_lr]
# Optimizer creation
params_to_optimize = (
itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
if args.train_text_encoder
else unet_lora_parameters
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warn(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warn(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
if args.optimizer.lower() == "adamw":
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.optimizer.lower() == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warn(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
if args.train_text_encoder and args.text_encoder_lr:
logger.warn(
f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
f"When using prodigy only learning_rate is used as the initial learning rate."
)
# changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
# --learning_rate
params_to_optimize[1]["lr"] = args.learning_rate
params_to_optimize[2]["lr"] = args.learning_rate
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_prompt=args.class_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_num=args.num_class_images,
size=args.resolution,
repeats=args.repeats,
center_crop=args.center_crop,
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
# Computes additional embeddings/ids required by the SDXL UNet.
# regular text emebddings (when `train_text_encoder` is not True)
# regular text embeddings (when `train_text_encoder` is not True)
# pooled text embeddings
# time ids
@@ -921,7 +1269,11 @@ def main(args):
# Handle instance prompt.
instance_time_ids = compute_time_ids()
if not args.train_text_encoder:
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers
)
@@ -934,49 +1286,36 @@ def main(args):
args.class_prompt, text_encoders, tokenizers
)
# Clear the memory here.
if not args.train_text_encoder:
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
gc.collect()
torch.cuda.empty_cache()
# Pack the statically computed variables appropriately. This is so that we don't
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
add_time_ids = instance_time_ids
if args.with_prior_preservation:
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
if not args.train_text_encoder:
prompt_embeds = instance_prompt_hidden_states
unet_add_text_embeds = instance_pooled_prompt_embeds
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
if args.with_prior_preservation:
class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_num=args.num_class_images,
size=args.resolution,
center_crop=args.center_crop,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=args.dataloader_num_workers,
)
if not train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
prompt_embeds = instance_prompt_hidden_states
unet_add_text_embeds = instance_pooled_prompt_embeds
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
# if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
# batch prompts on all training steps
else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
if args.with_prior_preservation:
class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
@@ -1079,6 +1418,17 @@ def main(args):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
prompts, text_encoders, tokenizers
)
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts)
tokens_two = tokenize_prompt(tokenizer_two, prompts)
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1099,16 +1449,21 @@ def main(args):
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# Calculate the elements to repeat depending on the use of prior-preservation.
elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
if not train_dataset.custom_instance_prompts:
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
else:
elems_to_repeat_text_embeds = 1
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
# Predict the noise residual
if not args.train_text_encoder:
unet_added_conditions = {
"time_ids": add_time_ids.repeat(elems_to_repeat, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
}
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input,
timesteps,
@@ -1116,15 +1471,17 @@ def main(args):
added_cond_kwargs=unet_added_conditions,
).sample
else:
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
prompt=None,
text_input_ids_list=[tokens_one, tokens_two],
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
unet_added_conditions.update(
{"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
)
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
).sample
@@ -1142,16 +1499,34 @@ def main(args):
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
base_weight = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective needs to be floored to an SNR weight of one.
mse_loss_weights = base_weight + 1
else:
# Epsilon and sample both use the same loss weights.
mse_loss_weights = base_weight
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
if args.with_prior_preservation:
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
@@ -1212,10 +1587,16 @@ def main(args):
# create pipeline
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path,
subfolder="text_encoder",
revision=args.revision,
variant=args.variant,
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path,
subfolder="text_encoder_2",
revision=args.revision,
variant=args.variant,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
@@ -1224,6 +1605,7 @@ def main(args):
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1301,10 +1683,15 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
@@ -1353,7 +1740,8 @@ def main(args):
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
@@ -78,6 +78,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -435,9 +441,11 @@ def main():
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
@@ -915,6 +923,7 @@ def main():
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -966,6 +975,7 @@ def main():
vae=accelerator.unwrap_model(vae),
unet=unet,
revision=args.revision,
variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
@@ -118,6 +118,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -484,9 +490,10 @@ def main():
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
@@ -695,10 +702,16 @@ def main():
# Load scheduler, tokenizer and models.
tokenizer_1 = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_2 = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
text_encoder_cls_1 = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
text_encoder_cls_2 = import_model_class_from_model_name_or_path(
@@ -708,10 +721,10 @@ def main():
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_1 = text_encoder_cls_1.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_2 = text_encoder_cls_2.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
# We ALWAYS pre-compute the additional condition embeddings needed for SDXL
@@ -1109,6 +1122,7 @@ def main():
tokenizer_2=tokenizer_2,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -1176,6 +1190,7 @@ def main():
vae=vae,
unet=unet,
revision=args.revision,
variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
+19 -5
View File
@@ -85,6 +85,7 @@ def log_validation(vae, unet, adapter, args, accelerator, weight_dtype, step):
unet=unet,
adapter=adapter,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -262,6 +263,12 @@ def parse_args(input_args=None):
" float32 precision."
),
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -812,10 +819,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -829,10 +842,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -843,9 +856,10 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
if args.adapter_model_name_or_path:
+43
View File
@@ -421,6 +421,49 @@ class ExamplesTestsAccelerate(unittest.TestCase):
)
self.assertTrue(starts_with_unet)
def test_dreambooth_lora_sdxl_custom_captions(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--caption_column text
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--dataset_name hf-internal-testing/dummy_image_text_data
--caption_column text
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--train_text_encoder
""".split()
run_command(self._launch_args + test_args)
def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+11 -3
View File
@@ -148,6 +148,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
unet=accelerator.unwrap_model(unet),
safety_checker=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -209,6 +210,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -567,10 +574,10 @@ def main():
# across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
@@ -585,7 +592,7 @@ def main():
# Create EMA for the unet.
if args.use_ema:
ema_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
@@ -1026,6 +1033,7 @@ def main():
vae=vae,
unet=unet,
revision=args.revision,
variant=args.variant,
)
pipeline.save_pretrained(args.output_dir)
@@ -54,6 +54,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -40,8 +40,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -54,6 +53,39 @@ check_min_version("0.24.0.dev0")
logger = get_logger(__name__, log_level="INFO")
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
@@ -98,6 +130,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -422,9 +460,11 @@ def main():
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
@@ -458,25 +498,43 @@ def main():
# => 32 layers
# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=args.rank,
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
unet.set_attn_processor(lora_attn_procs)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -491,8 +549,6 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
lora_layers = AttnProcsLayers(unet.attn_processors)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
@@ -517,7 +573,7 @@ def main():
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
lora_layers.parameters(),
unet_lora_parameters,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -644,8 +700,8 @@ def main():
)
# Prepare everything with our `accelerator`.
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -777,7 +833,7 @@ def main():
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers.parameters()
params_to_clip = unet_lora_parameters
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -833,6 +889,7 @@ def main():
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -889,7 +946,7 @@ def main():
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
)
pipeline = pipeline.to(accelerator.device)
@@ -50,7 +50,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available
@@ -63,6 +63,39 @@ check_min_version("0.24.0.dev0")
logger = get_logger(__name__)
# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
return attn_modules
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def save_model_card(
repo_id: str,
images=None,
@@ -147,6 +180,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -537,10 +576,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -554,10 +599,10 @@ def main(args):
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -565,10 +610,13 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# We only train the additional adapter LoRA layers
@@ -1143,6 +1191,7 @@ def main(args):
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
@@ -1208,7 +1257,11 @@ def main(args):
# Final inference
# Load previous pipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -148,6 +148,12 @@ def parse_args(input_args=None):
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--dataset_name",
type=str,
@@ -618,10 +624,16 @@ def main(args):
# Load the tokenizers
tokenizer_one = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
tokenizer_two = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
use_fast=False,
)
# import correct text encoder classes
@@ -636,10 +648,10 @@ def main(args):
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# Check for terminal SNR in combination with SNR Gamma
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
vae_path = (
args.pretrained_model_name_or_path
@@ -647,10 +659,13 @@ def main(args):
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# Freeze vae and text encoders.
@@ -677,7 +692,7 @@ def main(args):
# Create EMA for the unet.
if args.use_ema:
ema_unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
@@ -1145,12 +1160,14 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
if args.prediction_type is not None:
@@ -1198,10 +1215,16 @@ def main(args):
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, unet=unet, vae=vae, revision=args.revision, torch_dtype=weight_dtype
args.pretrained_model_name_or_path,
unet=unet,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
if args.prediction_type is not None:
scheduler_args = {"prediction_type": args.prediction_type}
@@ -126,6 +126,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
vae=vae,
safety_checker=None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
@@ -206,6 +207,12 @@ def parse_args():
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
@@ -624,9 +631,11 @@ def main():
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
)
# Add the placeholder token in tokenizer
@@ -752,6 +761,7 @@ def main():
num_cycles=args.lr_num_cycles,
)
text_encoder.train()
# Prepare everything with our `accelerator`.
text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder, optimizer, train_dataloader, lr_scheduler
+98
View File
@@ -0,0 +1,98 @@
#!/usr/bin/env python3
import argparse
import fnmatch
from safetensors.torch import load_file
from diffusers import Kandinsky3UNet
MAPPING = {
"to_time_embed.1": "time_embedding.linear_1",
"to_time_embed.3": "time_embedding.linear_2",
"in_layer": "conv_in",
"out_layer.0": "conv_norm_out",
"out_layer.2": "conv_out",
"down_samples": "down_blocks",
"up_samples": "up_blocks",
"projection_lin": "encoder_hid_proj.projection_linear",
"projection_ln": "encoder_hid_proj.projection_norm",
"feature_pooling": "add_time_condition",
"to_query": "to_q",
"to_key": "to_k",
"to_value": "to_v",
"output_layer": "to_out.0",
"self_attention_block": "attentions.0",
}
DYNAMIC_MAP = {
"resnet_attn_blocks.*.0": "resnets_in.*",
"resnet_attn_blocks.*.1": ("attentions.*", 1),
"resnet_attn_blocks.*.2": "resnets_out.*",
}
# MAPPING = {}
def convert_state_dict(unet_state_dict):
"""
Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
Args:
unet_model (torch.nn.Module): The original U-Net model.
unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
Returns:
OrderedDict: The converted state dictionary.
"""
# Example of renaming logic (this will vary based on your model's architecture)
converted_state_dict = {}
for key in unet_state_dict:
new_key = key
for pattern, new_pattern in MAPPING.items():
new_key = new_key.replace(pattern, new_pattern)
for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
has_matched = False
if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
if isinstance(dyn_new_pattern, tuple):
new_star = star + dyn_new_pattern[-1]
dyn_new_pattern = dyn_new_pattern[0]
else:
new_star = star
pattern = dyn_pattern.replace("*", str(star))
new_pattern = dyn_new_pattern.replace("*", str(new_star))
new_key = new_key.replace(pattern, new_pattern)
has_matched = True
converted_state_dict[new_key] = unet_state_dict[key]
return converted_state_dict
def main(model_path, output_path):
# Load your original U-Net model
unet_state_dict = load_file(model_path)
# Initialize your Kandinsky3UNet model
config = {}
# Convert the state dict
converted_state_dict = convert_state_dict(unet_state_dict)
unet = Kandinsky3UNet(config)
unet.load_state_dict(converted_state_dict)
unet.save_pretrained(output_path)
print(f"Converted model saved to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
args = parser.parse_args()
main(args.model_path, args.output_path)
+730
View File
@@ -0,0 +1,730 @@
from diffusers.utils import is_accelerate_available, logging
if is_accelerate_available():
pass
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
else:
if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
unet_params = original_config.model.params.unet_config.params
else:
unet_params = original_config.model.params.network_config.params
vae_params = original_config.model.params.first_stage_config.params.encoder_config.params
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = (
"CrossAttnDownBlockSpatioTemporal"
if resolution in unet_params.attention_resolutions
else "DownBlockSpatioTemporal"
)
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = (
"CrossAttnUpBlockSpatioTemporal"
if resolution in unet_params.attention_resolutions
else "UpBlockSpatioTemporal"
)
up_block_types.append(block_type)
resolution //= 2
if unet_params.transformer_depth is not None:
transformer_layers_per_block = (
unet_params.transformer_depth
if isinstance(unet_params.transformer_depth, int)
else list(unet_params.transformer_depth)
)
else:
transformer_layers_per_block = 1
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
class_embed_type = None
addition_embed_type = None
addition_time_embed_dim = None
projection_class_embeddings_input_dim = None
context_dim = None
if unet_params.context_dim is not None:
context_dim = (
unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
)
if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
addition_time_embed_dim = 256
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"cross_attention_dim": context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"addition_embed_type": addition_embed_type,
"addition_time_embed_dim": addition_time_embed_dim,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
"transformer_layers_per_block": transformer_layers_per_block,
}
if "disable_self_attentions" in unet_params:
config["only_cross_attention"] = unet_params.disable_self_attentions
if "num_classes" in unet_params and isinstance(unet_params.num_classes, int):
config["num_class_embeds"] = unet_params.num_classes
if controlnet:
config["conditioning_channels"] = unet_params.hint_channels
else:
config["out_channels"] = unet_params.out_channels
config["up_block_types"] = tuple(up_block_types)
return config
def assign_to_checkpoint(
paths,
checkpoint,
old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None,
config=None,
mid_block_suffix="",
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
if mid_block_suffix is not None:
mid_block_suffix = f".{mid_block_suffix}"
else:
mid_block_suffix = ""
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", f"mid_block.resnets.0{mid_block_suffix}")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", f"mid_block.resnets.1{mid_block_suffix}")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
if new_path == "mid_block.resnets.0.spatial_res_block.norm1.weight":
print("yeyy")
# proj_attn.weight has to be converted from conv 1D to linear
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
shape = old_checkpoint[path["old"]].shape
if is_attn_weight and len(shape) == 3:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
elif is_attn_weight and len(shape) == 4:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
new_item = new_item.replace("time_stack", "temporal_transformer_blocks")
new_item = new_item.replace("time_pos_embed.0.bias", "time_pos_embed.linear_1.bias")
new_item = new_item.replace("time_pos_embed.0.weight", "time_pos_embed.linear_1.weight")
new_item = new_item.replace("time_pos_embed.2.bias", "time_pos_embed.linear_2.bias")
new_item = new_item.replace("time_pos_embed.2.weight", "time_pos_embed.linear_2.weight")
mapping.append({"old": old_item, "new": new_item})
return mapping
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = new_item.replace("time_stack.", "")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_unet_checkpoint(
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
if skip_extract_state_dict:
unet_state_dict = checkpoint
else:
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
logger.warning(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
logger.warning(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
if config["class_embed_type"] is None:
# No parameters to port
...
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
else:
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
# if config["addition_embed_type"] == "text_time":
new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
spatial_resnets = [
key
for key in input_blocks[i]
if f"input_blocks.{i}.0" in key
and (
f"input_blocks.{i}.0.op" not in key
and f"input_blocks.{i}.0.time_stack" not in key
and f"input_blocks.{i}.0.time_mixer" not in key
)
]
temporal_resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0.time_stack" in key]
# import ipdb; ipdb.set_trace()
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(spatial_resnets)
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
paths = renew_resnet_paths(temporal_resnets)
meta_path = {
"old": f"input_blocks.{i}.0",
"new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
# TODO resnet time_mixer.mix_factor
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
# import ipdb; ipdb.set_trace()
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_spatial = [key for key in resnet_0 if "time_stack" not in key and "time_mixer" not in key]
resnet_0_paths = renew_resnet_paths(resnet_0_spatial)
# import ipdb; ipdb.set_trace()
assign_to_checkpoint(
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
)
resnet_0_temporal = [key for key in resnet_0 if "time_stack" in key and "time_mixer" not in key]
resnet_0_paths = renew_resnet_paths(resnet_0_temporal)
assign_to_checkpoint(
resnet_0_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
)
resnet_1_spatial = [key for key in resnet_1 if "time_stack" not in key and "time_mixer" not in key]
resnet_1_paths = renew_resnet_paths(resnet_1_spatial)
assign_to_checkpoint(
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="spatial_res_block"
)
resnet_1_temporal = [key for key in resnet_1 if "time_stack" in key and "time_mixer" not in key]
resnet_1_paths = renew_resnet_paths(resnet_1_temporal)
assign_to_checkpoint(
resnet_1_paths, new_checkpoint, unet_state_dict, config=config, mid_block_suffix="temporal_res_block"
)
new_checkpoint["mid_block.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
"middle_block.0.time_mixer.mix_factor"
]
new_checkpoint["mid_block.resnets.1.time_mixer.mix_factor"] = unet_state_dict[
"middle_block.2.time_mixer.mix_factor"
]
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
spatial_resnets = [
key
for key in output_blocks[i]
if f"output_blocks.{i}.0" in key
and (f"output_blocks.{i}.0.time_stack" not in key and "time_mixer" not in key)
]
# import ipdb; ipdb.set_trace()
temporal_resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0.time_stack" in key]
paths = renew_resnet_paths(spatial_resnets)
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.spatial_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
paths = renew_resnet_paths(temporal_resnets)
meta_path = {
"old": f"output_blocks.{i}.0",
"new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}.temporal_res_block",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
new_checkpoint[
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key and "conv" not in key]
if len(attentions):
paths = renew_attention_paths(attentions)
# import ipdb; ipdb.set_trace()
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
spatial_layers = [
layer for layer in output_block_layers if "time_stack" not in layer and "time_mixer" not in layer
]
resnet_0_paths = renew_resnet_paths(spatial_layers, n_shave_prefix_segments=1)
# import ipdb; ipdb.set_trace()
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(
["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "spatial_res_block", path["new"]]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
temporal_layers = [
layer for layer in output_block_layers if "time_stack" in layer and "time_mixer" not in key
]
resnet_0_paths = renew_resnet_paths(temporal_layers, n_shave_prefix_segments=1)
# import ipdb; ipdb.set_trace()
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(
["up_blocks", str(block_id), "resnets", str(layer_in_block_id), "temporal_res_block", path["new"]]
)
new_checkpoint[new_path] = unet_state_dict[old_path]
new_checkpoint["up_blocks.0.resnets.0.time_mixer.mix_factor"] = unet_state_dict[
f"output_blocks.{str(i)}.0.time_mixer.mix_factor"
]
return new_checkpoint
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0, is_temporal=False):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# Temporal resnet
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = new_item.replace("time_stack.", "temporal_res_block.")
# Spatial resnet
new_item = new_item.replace("conv1", "spatial_res_block.conv1")
new_item = new_item.replace("norm1", "spatial_res_block.norm1")
new_item = new_item.replace("conv2", "spatial_res_block.conv2")
new_item = new_item.replace("norm2", "spatial_res_block.norm2")
new_item = new_item.replace("nin_shortcut", "spatial_res_block.conv_shortcut")
new_item = new_item.replace("mix_factor", "spatial_res_block.time_mixer.mix_factor")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
keys = list(checkpoint.keys())
vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["decoder.time_conv_out.weight"] = vae_state_dict["decoder.time_mix_conv.weight"]
new_checkpoint["decoder.time_conv_out.bias"] = vae_state_dict["decoder.time_mix_conv.bias"]
# new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
# new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
# new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
# new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
+12
View File
@@ -76,9 +76,11 @@ else:
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
"Kandinsky3UNet",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -91,6 +93,7 @@ else:
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"VQModel",
]
)
@@ -214,6 +217,8 @@ else:
"IFPipeline",
"IFSuperResolutionPipeline",
"ImageTextPipelineOutput",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
@@ -264,6 +269,7 @@ else:
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableDiffusionVideoPipeline",
"StableDiffusionXLAdapterPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
@@ -443,9 +449,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .models import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
Kandinsky3UNet,
ModelMixin,
MotionAdapter,
MultiAdapter,
@@ -458,6 +466,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
VQModel,
)
from .optimization import (
@@ -560,6 +569,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
IFPipeline,
IFSuperResolutionPipeline,
ImageTextPipelineOutput,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
@@ -610,6 +621,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionVideoPipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
+1 -1
View File
@@ -326,7 +326,7 @@ class VaeImageProcessor(ConfigMixin):
# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
if image.min() < 0 and do_normalize:
if do_normalize and image.min() < 0:
warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
+4 -3
View File
@@ -8,7 +8,7 @@ def text_encoder_lora_state_dict(text_encoder):
deprecate(
"text_encoder_load_state_dict in `models`",
"0.27.0",
"`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
"`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
)
state_dict = {}
@@ -34,7 +34,7 @@ if is_transformers_available():
deprecate(
"text_encoder_attn_modules in `models`",
"0.27.0",
"`text_encoder_lora_state_dict` has been moved to `diffusers.models.lora`. Please make sure to import it via `from diffusers.models.lora import text_encoder_lora_state_dict`.",
"`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
)
from transformers import CLIPTextModel, CLIPTextModelWithProjection
@@ -62,16 +62,17 @@ if is_torch_available():
_import_structure["single_file"].extend(["FromSingleFileMixin"])
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from ..models.lora import text_encoder_lora_state_dict
from .single_file import FromOriginalControlnetMixin, FromOriginalVAEMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import IPAdapterMixin
from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
+157
View File
@@ -0,0 +1,157 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, Union
import torch
from safetensors import safe_open
from ..utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
is_transformers_available,
logging,
)
if is_transformers_available():
from transformers import (
CLIPImageProcessor,
CLIPVisionModelWithProjection,
)
from ..models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
logger = logging.get_logger(__name__)
class IPAdapterMixin:
"""Mixin for handling IP Adapters."""
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
subfolder: str,
weight_name: str,
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
incompletely downloaded files are deleted.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
# load CLIP image encoer here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pretrained_model_name_or_path_or_dict,
subfolder=os.path.join(subfolder, "image_encoder"),
).to(self.device, dtype=self.dtype)
self.image_encoder = image_encoder
else:
raise ValueError("`image_encoder` cannot be None when using IP Adapters.")
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
self.feature_extractor = CLIPImageProcessor()
# load ip-adapter into unet
self.unet._load_ip_adapter_weights(state_dict)
def set_ip_adapter_scale(self, scale):
for attn_processor in self.unet.attn_processors.values():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
attn_processor.scale = scale
+33 -4
View File
@@ -47,9 +47,10 @@ from ..utils import (
if is_transformers_available():
from transformers import PreTrainedModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection
from ..models.lora import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
# To be deprecated soon
from ..models.lora import PatchedLoraProjection
if is_accelerate_available():
from accelerate import init_empty_weights
@@ -66,6 +67,34 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
def text_encoder_attn_modules(text_encoder):
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
else:
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
return attn_modules
def text_encoder_mlp_modules(text_encoder):
mlp_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
mlp_mod = layer.mlp
name = f"text_model.encoder.layers.{i}.mlp"
mlp_modules.append((name, mlp_mod))
else:
raise ValueError(f"do not know how to get mlp modules for: {text_encoder.__class__.__name__}")
return mlp_modules
class LoraLoaderMixin:
r"""
Load LoRA layers into [`UNet2DConditionModel`] and [`~transformers.CLIPTextModel`].
@@ -1415,7 +1444,7 @@ class LoraLoaderMixin:
)
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
def disable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821
"""
Disable the text encoder's LoRA layers.
@@ -1445,7 +1474,7 @@ class LoraLoaderMixin:
raise ValueError("Text Encoder not found.")
set_adapter_layers(text_encoder, enabled=False)
def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None):
def enable_lora_for_text_encoder(self, text_encoder: Optional["PreTrainedModel"] = None): # noqa: F821
"""
Enables the text encoder's LoRA layers.
+70
View File
@@ -18,8 +18,10 @@ from typing import Callable, Dict, List, Optional, Union
import safetensors
import torch
import torch.nn.functional as F
from torch import nn
from ..models.embeddings import ImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import (
DIFFUSERS_CACHE,
@@ -662,4 +664,72 @@ class UNet2DConditionLoadersMixin:
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
def _load_ip_adapter_weights(self, state_dict):
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
)
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 1
for name in self.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = self.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]
if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = (
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).to(dtype=self.dtype, device=self.device)
value_dict = {}
for k, w in attn_procs[name].state_dict().items():
value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
attn_procs[name].load_state_dict(value_dict)
key_id += 2
self.set_attn_processor(attn_procs)
# create image projection layers.
clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
)
image_projection.to(dtype=self.dtype, device=self.device)
# load image projection layer weights
image_proj_state_dict = {}
image_proj_state_dict.update(
{
"image_embeds.weight": state_dict["image_proj"]["proj.weight"],
"image_embeds.bias": state_dict["image_proj"]["proj.bias"],
"norm.weight": state_dict["image_proj"]["norm.weight"],
"norm.bias": state_dict["image_proj"]["norm.bias"],
}
)
image_projection.load_state_dict(image_proj_state_dict)
self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
self.config.encoder_hid_dim_type = "ip_image_proj"
delete_adapter_layers
+12 -1
View File
@@ -14,7 +14,12 @@
from typing import TYPE_CHECKING
from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
is_flax_available,
is_torch_available,
)
_import_structure = {}
@@ -23,6 +28,7 @@ if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
@@ -36,7 +42,9 @@ if is_torch_available():
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]
if is_flax_available():
@@ -50,6 +58,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
@@ -63,7 +72,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel
if is_flax_available():
+181 -1
View File
@@ -194,7 +194,12 @@ class BasicTransformerBlock(nn.Module):
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
@@ -339,6 +344,181 @@ class BasicTransformerBlock(nn.Module):
return hidden_states
@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video like data.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
"""
def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
norm_eps: float = 1e-5,
final_dropout: bool = False,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim
self.norm_in = nn.LayerNorm(dim, eps=norm_eps)
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim, eps=norm_eps)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
dropout=dropout,
activation_fn="geglu",
final_dropout=final_dropout,
)
self.norm1 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
cross_attention_dim=None,
)
# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim, eps=norm_eps)
self.ff = FeedForward(
time_mix_inner_dim,
dropout=dropout,
activation_fn="geglu",
final_dropout=final_dropout,
)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
residual = hidden_states
hidden_states = self.norm_in(hidden_states)
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output
hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
+286 -1
View File
@@ -16,7 +16,7 @@ from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch import einsum, nn
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -1975,6 +1975,288 @@ class LoRAAttnAddedKVProcessor(nn.Module):
return attn.processor(attn, hidden_states, *args, **kwargs)
class IPAdapterAttnProcessor(nn.Module):
r"""
Attention processor for IP-Adapater.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, defaults to 4):
The context length of the image features.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=1.0,
):
if scale != 1.0:
logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# split hidden states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAdapterAttnProcessor2_0(torch.nn.Module):
r"""
Attention processor for IP-Adapater for PyTorch 2.0.
Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, defaults to 4):
The context length of the image features.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.num_tokens = num_tokens
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
scale=1.0,
):
if scale != 1.0:
logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# split hidden states
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
encoder_hidden_states[:, end_pos:, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
# this way torch.compile and co. will work as well
class Kandi3AttnProcessor:
r"""
Default kandinsky3 proccesor for performing attention-related computations.
"""
@staticmethod
def _reshape(hid_states, h):
b, n, f = hid_states.shape
d = f // h
return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
def __call__(
self,
attn,
x,
context,
context_mask=None,
):
query = self._reshape(attn.to_q(x), h=attn.num_heads)
key = self._reshape(attn.to_k(context), h=attn.num_heads)
value = self._reshape(attn.to_v(context), h=attn.num_heads)
attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
if context_mask is not None:
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = context_mask.unsqueeze(1).unsqueeze(1)
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
out = attn.to_out[0](out)
return out
LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -1998,6 +2280,9 @@ CROSS_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
Kandi3AttnProcessor,
)
AttentionProcessor = Union[
@@ -108,6 +108,9 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.use_slicing = False
self.use_tiling = False
self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
@@ -0,0 +1,672 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalVAEMixin
from ..utils import BaseOutput, is_torch_version
from ..utils.accelerate_utils import apply_forward_hook
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .modeling_utils import ModelMixin
from .unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
def __init__(
self,
in_channels: int = 4,
out_channels: int = 3,
block_out_channels: Tuple[int, ...] = (
128,
256,
512,
512,
),
layers_per_block: int = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
norm_type: str = "group", # group, spatial
alpha: float = 0.0,
merge_strategy: str = "learned",
conv_out_kernel_size=(3, 1, 1),
):
super().__init__()
self.layers_per_block = layers_per_block
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
temb_channels = in_channels if norm_type == "spatial" else None
self.mid_block = MidBlockTemporalDecoder(
num_layers=self.layers_per_block,
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
attention_head_dim=block_out_channels[-1],
resnet_eps=1e-6,
temporal_resnet_eps=1e-5,
resnet_act_fn=act_fn,
norm_num_groups=norm_num_groups,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
merge_factor=alpha,
merge_strategy=merge_strategy,
)
# up
self.up_blocks = nn.ModuleList([])
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i in range(len(block_out_channels)):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
up_block = UpBlockTemporalDecoder(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
out_channels=output_channel,
add_upsample=not is_final_block,
resnet_eps=1e-6,
temporal_resnet_eps=1e-5,
resnet_act_fn=act_fn,
norm_num_groups=norm_num_groups,
attention_head_dim=output_channel,
temb_channels=temb_channels,
resnet_time_scale_shift=norm_type,
merge_factor=alpha,
merge_strategy=merge_strategy,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
if isinstance(conv_out_kernel_size, Iterable):
padding = [int(k // 2) for k in conv_out_kernel_size]
else:
padding = int(conv_out_kernel_size // 2)
self.conv_act = nn.SiLU()
self.conv_out = torch.nn.Conv2d(
in_channels=block_out_channels[0],
out_channels=out_channels,
kernel_size=3,
padding=1,
)
self.time_conv_out = torch.nn.Conv3d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=conv_out_kernel_size,
padding=padding,
)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.FloatTensor,
image_only_indicator: torch.FloatTensor,
num_frames: int = 1,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
sample = self.conv_in(sample)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
latent_embeds,
num_frames,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
latent_embeds,
num_frames,
use_reentrant=False,
)
else:
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
image_only_indicator,
latent_embeds,
num_frames,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
image_only_indicator,
latent_embeds,
num_frames,
)
else:
# middle
sample = self.mid_block(
sample,
temb=latent_embeds,
num_frames=num_frames,
image_only_indicator=image_only_indicator,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(
sample,
temb=latent_embeds,
num_frames=num_frames,
image_only_indicator=image_only_indicator,
)
# post-process
if latent_embeds is None:
sample = self.conv_norm_out(sample)
else:
sample = self.conv_norm_out(sample, latent_embeds)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
batch_frames, channels, height, width = sample.shape
batch_size = batch_frames // num_frames
sample = sample[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
sample = self.time_conv_out(sample)
sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return sample
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
scaling_factor: float = 0.18215,
force_upcast: float = True,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
# pass init params to Decoder
self.decoder = TemporalDecoder(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, TemporalDecoder)):
module.gradient_checkpointing = value
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.tiled_encode(x, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(
self, z: torch.FloatTensor, num_frames: int, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
batch_size = z.shape[0] // num_frames
# TODO: dont hardcode this
image_only_indicator = torch.zeros(batch_size, num_frames, dtype=z.dtype, device=z.device)
dec = self.decoder(z, num_frames=num_frames, image_only_indicator=image_only_indicator)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self,
z: torch.FloatTensor,
num_frames: int,
return_dict: bool = True,
generator=None,
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice, num_frames // 2).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z, num_frames).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[
:,
:,
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[
:,
:,
i : i + self.tile_latent_min_size,
j : j + self.tile_latent_min_size,
]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
num_frames: int = 1,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, num_frames=num_frames).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
+3
View File
@@ -148,6 +148,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
self.tile_sample_min_size = 512
self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor
self.register_to_config(block_out_channels=decoder_block_out_channels)
self.register_to_config(force_upcast=False)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (EncoderTiny, DecoderTiny)):
module.gradient_checkpointing = value
@@ -138,6 +138,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
)
self.decoder_scheduler = ConsistencyDecoderScheduler()
self.register_to_config(block_out_channels=encoder_block_out_channels)
self.register_to_config(force_upcast=False)
self.register_buffer(
"means",
torch.tensor([0.38862467, 0.02253063, 0.07381133, -0.0171294])[None, :, None, None],
+9 -19
View File
@@ -12,6 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# IMPORTANT: #
###################################################################
# ----------------------------------------------------------------#
# This file is deprecated and will be removed soon #
# (as soon as PEFT will become a required dependency for LoRA) #
# ----------------------------------------------------------------#
###################################################################
from typing import Optional, Tuple, Union
import torch
@@ -57,25 +66,6 @@ def text_encoder_mlp_modules(text_encoder):
return mlp_modules
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
+354 -12
View File
@@ -165,7 +165,10 @@ class Upsample2D(nn.Module):
self.Conv2d_0 = conv
def forward(
self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, scale: float = 1.0
self,
hidden_states: torch.FloatTensor,
output_size: Optional[int] = None,
scale: float = 1.0,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
@@ -379,7 +382,11 @@ class FirUpsample2D(nn.Module):
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
inverse_conv = F.conv_transpose2d(
hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
hidden_states,
weight,
stride=stride,
output_padding=output_padding,
padding=0,
)
output = upfirdn2d_native(
@@ -530,7 +537,14 @@ class KDownsample2D(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -553,7 +567,14 @@ class KUpsample2D(nn.Module):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
weight = inputs.new_zeros(
[
inputs.shape[1],
inputs.shape[1],
self.kernel.shape[0],
self.kernel.shape[1],
]
)
indices = torch.arange(inputs.shape[1], device=inputs.device)
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
weight[indices, indices] = kernel
@@ -690,11 +711,19 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=conv_shortcut_bias,
)
def forward(
self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, scale: float = 1.0
self,
input_tensor: torch.FloatTensor,
temb: torch.FloatTensor,
scale: float = 1.0,
) -> torch.FloatTensor:
hidden_states = input_tensor
@@ -866,7 +895,10 @@ class ResidualTemporalBlock1D(nn.Module):
def upsample_2d(
hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Upsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
@@ -910,7 +942,10 @@ def upsample_2d(
def downsample_2d(
hidden_states: torch.FloatTensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
hidden_states: torch.FloatTensor,
kernel: Optional[torch.FloatTensor] = None,
factor: int = 2,
gain: float = 1,
) -> torch.FloatTensor:
r"""Downsample2D a batch of 2D images with the given filter.
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
@@ -946,13 +981,20 @@ def downsample_2d(
kernel = kernel * gain
pad_value = kernel.shape[0] - factor
output = upfirdn2d_native(
hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
hidden_states,
kernel.to(device=hidden_states.device),
down=factor,
pad=((pad_value + 1) // 2, pad_value // 2),
)
return output
def upfirdn2d_native(
tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
tensor: torch.Tensor,
kernel: torch.Tensor,
up: int = 1,
down: int = 1,
pad: Tuple[int, int] = (0, 0),
) -> torch.Tensor:
up_x = up_y = up
down_x = down_y = down
@@ -1008,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
"""
def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0, norm_num_groups: int = 32):
def __init__(
self,
in_dim: int,
out_dim: Optional[int] = None,
dropout: float = 0.0,
norm_num_groups: int = 32,
):
super().__init__()
out_dim = out_dim or in_dim
self.in_dim = in_dim
@@ -1016,7 +1064,9 @@ class TemporalConvLayer(nn.Module):
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(norm_num_groups, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
nn.GroupNorm(norm_num_groups, in_dim),
nn.SiLU(),
nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.conv2 = nn.Sequential(
nn.GroupNorm(norm_num_groups, out_dim),
@@ -1058,3 +1108,295 @@ class TemporalConvLayer(nn.Module):
(hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
)
return hidden_states
class TemporalResnetBlock(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
groups_out (`int`, *optional*, default to None):
The number of groups to use for the second normalization layer. if set to None, same as `groups`.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
"ada_group" for a stronger conditioning with scale and shift.
kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
[`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
use_in_shortcut (`bool`, *optional*, default to `True`):
If `True`, add a 1x1 nn.conv2d layer for skip-connection.
up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
`conv_shortcut` output.
conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
If None, same as `out_channels`.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
eps: float = 1e-6,
non_linearity: str = "swish",
kernel_size: Optional[torch.FloatTensor] = (3, 1, 1),
output_scale_factor: float = 1.0,
use_in_shortcut: Optional[bool] = None,
conv_shortcut_bias: bool = True,
conv_2d_out_channels: Optional[int] = None,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.output_scale_factor = output_scale_factor
linear_cls = nn.Linear
conv_cls = nn.Conv3d
padding = [k // 2 for k in kernel_size]
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = conv_cls(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
)
if temb_channels is not None:
self.time_emb_proj = linear_cls(temb_channels, out_channels)
else:
self.time_emb_proj = None
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = conv_cls(
out_channels,
conv_2d_out_channels,
kernel_size=kernel_size,
stride=1,
padding=padding,
)
self.nonlinearity = get_activation(non_linearity)
self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = conv_cls(
in_channels,
conv_2d_out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=conv_shortcut_bias,
)
def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if self.time_emb_proj is not None:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, :, None, None]
if temb is not None:
temb = temb.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
# VideoResBlock
class SpatioTemporalResBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
pre_norm: bool = True,
eps: float = 1e-6,
temporal_eps: Optional[float] = None,
non_linearity: str = "swish",
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
output_scale_factor: float = 1.0,
kernel_size_3d: Optional[torch.FloatTensor] = (3, 1, 1),
merge_factor: float = 0.5,
merge_strategy="learned",
switch_spatial_to_temporal_mix: bool = False,
):
super().__init__()
self.spatial_res_block = ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=eps,
groups=groups,
dropout=dropout,
time_embedding_norm=time_embedding_norm,
non_linearity=non_linearity,
output_scale_factor=output_scale_factor,
pre_norm=pre_norm,
)
self.temporal_res_block = TemporalResnetBlock(
in_channels=out_channels if out_channels is not None else in_channels,
out_channels=out_channels if out_channels is not None else in_channels,
temb_channels=temb_channels,
eps=temporal_eps if temporal_eps is not None else eps,
groups=groups,
dropout=dropout,
non_linearity=non_linearity,
output_scale_factor=output_scale_factor,
kernel_size=kernel_size_3d,
)
self.time_mixer = AlphaBlender(
alpha=merge_factor,
merge_strategy=merge_strategy,
switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: Optional[torch.FloatTensor] = None,
num_frames: int = 1,
image_only_indicator: Optional[torch.Tensor] = None,
scale: float = 1.0,
):
hidden_states = self.spatial_res_block(hidden_states, temb, scale=scale)
batch_frames, channels, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
hidden_states_mix = (
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
)
hidden_states = (
hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
)
if temb is not None:
temb = temb.reshape(batch_size, num_frames, -1)
hidden_states = self.temporal_res_block(hidden_states, temb)
hidden_states = self.time_mixer(
x_spatial=hidden_states_mix,
x_temporal=hidden_states,
image_only_indicator=image_only_indicator,
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
return hidden_states
class AlphaBlender(nn.Module):
strategies = ["learned", "fixed", "learned_with_images"]
def __init__(
self,
alpha: float,
merge_strategy: str = "learned_with_images",
switch_spatial_to_temporal_mix: bool = False,
):
super().__init__()
self.merge_strategy = merge_strategy
self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix # For TemporalVAE
assert merge_strategy in self.strategies, f"merge_strategy needs to be in {self.strategies}"
if self.merge_strategy == "fixed":
self.register_buffer("mix_factor", torch.Tensor([alpha]))
elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
else:
raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
if self.merge_strategy == "fixed":
alpha = self.mix_factor
elif self.merge_strategy == "learned":
alpha = torch.sigmoid(self.mix_factor)
elif self.merge_strategy == "learned_with_images":
assert (
image_only_indicator is not None
), "Please provide image_only_indicator to use learned_with_images merge strategy"
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
torch.sigmoid(self.mix_factor)[..., None],
)
# (batch, channel, frames, height, width)
if ndims == 5:
alpha = alpha[:, None, :, None, None]
# (batch*frames, height*width, channels)
elif ndims == 3:
alpha = alpha.reshape(-1)[:, None, None]
else:
raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
else:
raise NotImplementedError
return alpha
def forward(
self,
x_spatial: torch.Tensor,
x_temporal: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
) -> torch.Tensor:
alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
alpha = alpha.to(x_spatial.dtype)
if self.switch_spatial_to_temporal_mix:
alpha = 1.0 - alpha
x = alpha * x_spatial + (1.0 - alpha) * x_temporal
return x
+20 -3
View File
@@ -20,7 +20,7 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from .attention import BasicTransformerBlock
from .embeddings import CaptionProjection, PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -70,6 +70,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -237,6 +239,10 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
@@ -360,8 +366,19 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
block,
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -369,7 +386,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
timestep,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
**ckpt_kwargs,
)
else:
hidden_states = block(
+229 -1
View File
@@ -19,8 +19,10 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .resnet import AlphaBlender
@dataclass
@@ -195,3 +197,229 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
return (output,)
return TransformerTemporalModelOutput(sample=output)
# VideoBlock
class TransformerSpatioTemporalModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
positional_embeddings: (`str`, *optional*):
The type of positional embeddings to apply to the sequence input before passing use.
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: int = 320,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
norm_eps: float = 1e-5,
merge_factor: float = 0.5,
merge_strategy: str = "learned_with_images",
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
linear_cls = nn.Linear
# 2. Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps)
self.proj_in = linear_cls(in_channels, inner_dim)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
norm_eps=norm_eps,
)
for d in range(num_layers)
]
)
time_mix_inner_dim = inner_dim
self.temporal_transformer_blocks = nn.ModuleList(
[
TemporalBasicTransformerBlock(
inner_dim,
time_mix_inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
time_embed_dim = in_channels * 4
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
self.time_mixer = AlphaBlender(alpha=merge_factor, merge_strategy=merge_strategy)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
# TODO: should use out_channels for continuous projections
self.proj_out = linear_cls(inner_dim, in_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
num_frames: int,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
num_frames (`int`, *optional*, defaults to 1):
The number of frames to be processed per batch. This is used to reshape the hidden states.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
assert (
encoder_hidden_states.ndim == 3
), f"n dims of spatial context should be 3 but are {encoder_hidden_states.ndim}"
# 1. Input
batch_frames, channel, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
time_context = encoder_hidden_states
time_context_first_timestep = time_context[::num_frames]
time_context = time_context_first_timestep.repeat(height * width, 1, 1)
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
num_frames_emb = num_frames_emb.reshape(-1)
t_emb = self.time_proj(num_frames_emb)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
# 2. Blocks
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
None,
encoder_hidden_states,
None,
timestep,
cross_attention_kwargs,
class_labels,
use_reentrant=False,
)
else:
hidden_states = block(
hidden_states,
attention_mask=None,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=None,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb
hidden_states_mix = temporal_block(
hidden_states_mix,
num_frames=num_frames,
encoder_hidden_states=time_context,
cross_attention_kwargs=cross_attention_kwargs,
)
hidden_states = self.time_mixer(
x_spatial=hidden_states,
x_temporal=hidden_states_mix,
image_only_indicator=image_only_indicator,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
if not return_dict:
return (output,)
return TransformerTemporalModelOutput(sample=output)
@@ -1022,6 +1022,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
# 2. pre-process
sample = self.conv_in(sample)
File diff suppressed because it is too large Load Diff
+3 -1
View File
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -22,6 +23,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
@@ -271,7 +273,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
self.conv_act = get_activation("silu")
else:
self.conv_norm_out = None
self.conv_act = None
+589
View File
@@ -0,0 +1,589 @@
import math
from dataclasses import dataclass
from typing import Dict, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, Kandi3AttnProcessor
from .embeddings import TimestepEmbedding
from .modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Kandinsky3UNetOutput(BaseOutput):
sample: torch.FloatTensor = None
# TODO(Yiyi): This class needs to be removed
def set_default_item(condition, item_1, item_2=None):
if condition:
return item_1
else:
return item_2
# TODO(Yiyi): This class needs to be removed
def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}):
if condition:
return layer_1(*args_1, **kwargs_1)
else:
return layer_2(*args_2, **kwargs_2)
# TODO(Yiyi): This class should be removed and be replaced by Timesteps
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, type_tensor=None):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
emb = x[:, None] * emb[None, :]
return torch.cat((emb.sin(), emb.cos()), dim=-1)
class Kandinsky3EncoderProj(nn.Module):
def __init__(self, encoder_hid_dim, cross_attention_dim):
super().__init__()
self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False)
self.projection_norm = nn.LayerNorm(cross_attention_dim)
def forward(self, x):
x = self.projection_linear(x)
x = self.projection_norm(x)
return x
class Kandinsky3UNet(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels: int = 4,
time_embedding_dim: int = 1536,
groups: int = 32,
attention_head_dim: int = 64,
layers_per_block: Union[int, Tuple[int]] = 3,
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
cross_attention_dim: Union[int, Tuple[int]] = 4096,
encoder_hid_dim: int = 4096,
):
super().__init__()
# TOOD(Yiyi): Give better name and put into config for the following 4 parameters
expansion_ratio = 4
compression_ratio = 2
add_cross_attention = (False, True, True, True)
add_self_attention = (False, True, True, True)
out_channels = in_channels
init_channels = block_out_channels[0] // 2
# TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same
# self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
self.time_proj = SinusoidalPosEmb(init_channels)
self.time_embedding = TimestepEmbedding(
init_channels,
time_embedding_dim,
)
self.add_time_condition = Kandinsky3AttentionPooling(
time_embedding_dim, cross_attention_dim, attention_head_dim
)
self.conv_in = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1)
self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim)
hidden_dims = [init_channels] + list(block_out_channels)
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention]
num_blocks = len(block_out_channels) * [layers_per_block]
layer_params = [num_blocks, text_dims, add_self_attention]
rev_layer_params = map(reversed, layer_params)
cat_dims = []
self.num_levels = len(in_out_dims)
self.down_blocks = nn.ModuleList([])
for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(
zip(in_out_dims, *layer_params)
):
down_sample = level != (self.num_levels - 1)
cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
self.down_blocks.append(
Kandinsky3DownSampleBlock(
in_dim,
out_dim,
time_embedding_dim,
text_dim,
res_block_num,
groups,
attention_head_dim,
expansion_ratio,
compression_ratio,
down_sample,
self_attention,
)
)
self.up_blocks = nn.ModuleList([])
for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(
zip(reversed(in_out_dims), *rev_layer_params)
):
up_sample = level != 0
self.up_blocks.append(
Kandinsky3UpSampleBlock(
in_dim,
cat_dims.pop(),
out_dim,
time_embedding_dim,
text_dim,
res_block_num,
groups,
attention_head_dim,
expansion_ratio,
compression_ratio,
up_sample,
self_attention,
)
)
self.conv_norm_out = nn.GroupNorm(groups, init_channels)
self.conv_act_out = nn.SiLU()
self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(Kandi3AttnProcessor())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
# TODO(Yiyi): Clean up the following variables - these names should not be used
# but instead only the ones that we pass to forward
x = sample
context_mask = encoder_attention_mask
context = encoder_hidden_states
if not torch.is_tensor(timestep):
dtype = torch.float32 if isinstance(timestep, float) else torch.int32
timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
elif len(timestep.shape) == 0:
timestep = timestep[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = timestep.expand(sample.shape[0])
time_embed_input = self.time_proj(timestep).to(x.dtype)
time_embed = self.time_embedding(time_embed_input)
context = self.encoder_hid_proj(context)
if context is not None:
time_embed = self.add_time_condition(time_embed, context, context_mask)
hidden_states = []
x = self.conv_in(x)
for level, down_sample in enumerate(self.down_blocks):
x = down_sample(x, time_embed, context, context_mask)
if level != self.num_levels - 1:
hidden_states.append(x)
for level, up_sample in enumerate(self.up_blocks):
if level != 0:
x = torch.cat([x, hidden_states.pop()], dim=1)
x = up_sample(x, time_embed, context, context_mask)
x = self.conv_norm_out(x)
x = self.conv_act_out(x)
x = self.conv_out(x)
if not return_dict:
return (x,)
return Kandinsky3UNetOutput(sample=x)
class Kandinsky3UpSampleBlock(nn.Module):
def __init__(
self,
in_channels,
cat_dim,
out_channels,
time_embed_dim,
context_dim=None,
num_blocks=3,
groups=32,
head_dim=64,
expansion_ratio=4,
compression_ratio=2,
up_sample=True,
self_attention=True,
):
super().__init__()
up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
hidden_channels = (
[(in_channels + cat_dim, in_channels)]
+ [(in_channels, in_channels)] * (num_blocks - 2)
+ [(in_channels, out_channels)]
)
attentions = []
resnets_in = []
resnets_out = []
self.self_attention = self_attention
self.context_dim = context_dim
attentions.append(
set_default_layer(
self_attention,
Kandinsky3AttentionBlock,
(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
)
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
)
attentions.append(
set_default_layer(
context_dim is not None,
Kandinsky3AttentionBlock,
(in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
)
resnets_out.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
self.attentions = nn.ModuleList(attentions)
self.resnets_in = nn.ModuleList(resnets_in)
self.resnets_out = nn.ModuleList(resnets_out)
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
x = resnet_in(x, time_embed)
if self.context_dim is not None:
x = attention(x, time_embed, context, context_mask, image_mask)
x = resnet_out(x, time_embed)
if self.self_attention:
x = self.attentions[0](x, time_embed, image_mask=image_mask)
return x
class Kandinsky3DownSampleBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
time_embed_dim,
context_dim=None,
num_blocks=3,
groups=32,
head_dim=64,
expansion_ratio=4,
compression_ratio=2,
down_sample=True,
self_attention=True,
):
super().__init__()
attentions = []
resnets_in = []
resnets_out = []
self.self_attention = self_attention
self.context_dim = context_dim
attentions.append(
set_default_layer(
self_attention,
Kandinsky3AttentionBlock,
(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
)
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
attentions.append(
set_default_layer(
context_dim is not None,
Kandinsky3AttentionBlock,
(out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
)
resnets_out.append(
Kandinsky3ResNetBlock(
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets_in = nn.ModuleList(resnets_in)
self.resnets_out = nn.ModuleList(resnets_out)
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
if self.self_attention:
x = self.attentions[0](x, time_embed, image_mask=image_mask)
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
x = resnet_in(x, time_embed)
if self.context_dim is not None:
x = attention(x, time_embed, context, context_mask, image_mask)
x = resnet_out(x, time_embed)
return x
class Kandinsky3ConditionalGroupNorm(nn.Module):
def __init__(self, groups, normalized_shape, context_dim):
super().__init__()
self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape))
self.context_mlp[1].weight.data.zero_()
self.context_mlp[1].bias.data.zero_()
def forward(self, x, context):
context = self.context_mlp(context)
for _ in range(len(x.shape[2:])):
context = context.unsqueeze(-1)
scale, shift = context.chunk(2, dim=1)
x = self.norm(x) * (scale + 1.0) + shift
return x
# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
# sure we can delete it and instead just pass an attention_mask
class Attention(nn.Module):
def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
super().__init__()
assert out_channels % head_dim == 0
self.num_heads = out_channels // head_dim
self.scale = head_dim**-0.5
# to_q
self.to_q = nn.Linear(in_channels, out_channels, bias=False)
# to_k
self.to_k = nn.Linear(context_dim, out_channels, bias=False)
# to_v
self.to_v = nn.Linear(context_dim, out_channels, bias=False)
processor = Kandi3AttnProcessor()
self.set_processor(processor)
# to_out
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
def set_processor(self, processor: "AttnProcessor"): # noqa: F821
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def forward(self, x, context, context_mask=None, image_mask=None):
return self.processor(
self,
x,
context=context,
context_mask=context_mask,
)
class Kandinsky3Block(nn.Module):
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
super().__init__()
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
self.activation = nn.SiLU()
self.up_sample = set_default_layer(
up_resolution is not None and up_resolution,
nn.ConvTranspose2d,
(in_channels, in_channels),
{"kernel_size": 2, "stride": 2},
)
padding = int(kernel_size > 1)
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
self.down_sample = set_default_layer(
up_resolution is not None and not up_resolution,
nn.Conv2d,
(out_channels, out_channels),
{"kernel_size": 2, "stride": 2},
)
def forward(self, x, time_embed):
x = self.group_norm(x, time_embed)
x = self.activation(x)
x = self.up_sample(x)
x = self.projection(x)
x = self.down_sample(x)
return x
class Kandinsky3ResNetBlock(nn.Module):
def __init__(
self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None]
):
super().__init__()
kernel_sizes = [1, 3, 3, 1]
hidden_channel = max(in_channels, out_channels) // compression_ratio
hidden_channels = (
[(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
)
self.resnet_blocks = nn.ModuleList(
[
Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
for (in_channel, out_channel), kernel_size, up_resolution in zip(
hidden_channels, kernel_sizes, up_resolutions
)
]
)
self.shortcut_up_sample = set_default_layer(
True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2}
)
self.shortcut_projection = set_default_layer(
in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1}
)
self.shortcut_down_sample = set_default_layer(
False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2}
)
def forward(self, x, time_embed):
out = x
for resnet_block in self.resnet_blocks:
out = resnet_block(out, time_embed)
x = self.shortcut_up_sample(x)
x = self.shortcut_projection(x)
x = self.shortcut_down_sample(x)
x = x + out
return x
class Kandinsky3AttentionPooling(nn.Module):
def __init__(self, num_channels, context_dim, head_dim=64):
super().__init__()
self.attention = Attention(context_dim, num_channels, context_dim, head_dim)
def forward(self, x, context, context_mask=None):
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
return x + context.squeeze(1)
class Kandinsky3AttentionBlock(nn.Module):
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
super().__init__()
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)
hidden_channels = expansion_ratio * num_channels
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
self.feed_forward = nn.Sequential(
nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
nn.SiLU(),
nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
)
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
height, width = x.shape[-2:]
out = self.in_norm(x, time_embed)
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
context = context if context is not None else out
if image_mask is not None:
mask_height, mask_width = image_mask.shape[-2:]
kernel_size = (mask_height // height, mask_width // width)
image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size)
image_mask = image_mask.reshape(image_mask.shape[0], -1)
out = self.attention(out, context, context_mask, image_mask)
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
x = x + out
out = self.out_norm(x, time_embed)
out = self.feed_forward(out)
x = x + out
return x
+16
View File
@@ -208,6 +208,8 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
use_motion_mid_block: int = True,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
):
super().__init__()
@@ -248,6 +250,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
act_fn=act_fn,
)
if encoder_hid_dim_type is None:
self.encoder_hid_proj = None
# class embedding
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
@@ -684,6 +689,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
@@ -767,6 +773,16 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
# 2. pre-process
@@ -0,0 +1,859 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import (
USE_PEFT_BACKEND,
BaseOutput,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
"""
The output of [`UNetSpatioTemporalConditionModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor = None
class UNetSpatioTemporalConditionModel(
ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin
):
r"""
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
Whether to include self-attention in the basic transformer blocks, see
[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
If `None`, normalization and activation layers is skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
Dimension for the timestep embeddings.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, defaults to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_dim (`int`, *optional*, defaults to `None`):
An optional override for the dimension of the projected time embedding.
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
Optional activation function to use only once on the time embeddings before they are passed to the rest of
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, defaults to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`,
*optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`,
*optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
otherwise.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 8,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
down_block_types: Tuple[str] = (
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
),
mid_block_type: Optional[str] = "UNetMidBlockSpatioTemporal",
up_block_types: Tuple[str] = (
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
),
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
projection_class_embeddings_input_dim: int = 768,
addition_time_embed_dim: int = 256,
layers_per_block: Union[int, Tuple[int]] = 2,
mid_block_scale_factor: float = 1,
dropout: float = 0.0,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1024,
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
time_embedding_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
kernel_size_3d: Optional[torch.FloatTensor] = (3, 1, 1),
merge_factor: float = 0.5,
merge_strategy: str = "learned_with_images",
):
super().__init__()
self.sample_size = sample_size
if num_attention_heads is not None:
raise ValueError(
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
)
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=conv_in_kernel,
padding=conv_in_padding,
)
# time
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos, downscale_freq_shift=0
)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
self.add_time_proj = Timesteps(
addition_time_embed_dim, flip_sin_to_cos, downscale_freq_shift=0
)
self.add_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim
)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(
down_block_types
)
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block[i],
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=1,
dropout=dropout,
kernel_size_3d=kernel_size_3d,
merge_factor=merge_factor,
merge_strategy=merge_strategy,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockSpatioTemporal(
block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
transformer_layers_per_block=transformer_layers_per_block[-1],
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dropout=dropout,
kernel_size_3d=kernel_size_3d,
merge_factor=merge_factor,
merge_strategy=merge_strategy,
)
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(
reversed(transformer_layers_per_block)
)
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[
min(i + 1, len(block_out_channels) - 1)
]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=reversed_layers_per_block[i] + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dropout=dropout,
kernel_size_3d=kernel_size_3d,
merge_factor=merge_factor,
merge_strategy=merge_strategy,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = get_activation(act_fn)
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = nn.Conv2d(
block_out_channels[0],
out_channels,
kernel_size=conv_out_kernel,
padding=conv_out_padding,
)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(
return_deprecated_lora=True
)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(
self,
processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
_remove_lora=False,
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(
processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnAddedKVProcessor()
elif all(
proc.__class__ in CROSS_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor, _remove_lora=True)
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = (
num_sliceable_layers * [slice_size]
if not isinstance(slice_size, list)
else slice_size
)
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(
module: torch.nn.Module, slice_size: List[int]
):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def enable_freeu(self, s1, s2, b1, b2):
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
The suffixes after the scaling factors represent the stage blocks where they are being applied.
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
Args:
s1 (`float`):
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
s2 (`float`):
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
mitigate the "oversmoothing effect" in the enhanced denoising process.
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
"""
for i, upsample_block in enumerate(self.up_blocks):
setattr(upsample_block, "s1", s1)
setattr(upsample_block, "s2", s2)
setattr(upsample_block, "b1", b1)
setattr(upsample_block, "b2", b2)
def disable_freeu(self):
"""Disables the FreeU mechanism."""
freeu_keys = {"s1", "s2", "b1", "b2"}
for i, upsample_block in enumerate(self.up_blocks):
for k in freeu_keys:
if (
hasattr(upsample_block, k)
or getattr(upsample_block, k, None) is not None
):
setattr(upsample_block, k, None)
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
added_time_ids: torch.Tensor,
image_only_indicator: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
r"""
The [`UNet2DConditionModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, channel, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
added_cond_kwargs: (`dict`):
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
Returns:
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
a `tuple` is returned where the first element is the sample tensor.
"""
# By default samples have to be AT least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
for dim in sample.shape[-2:]:
if dim % default_overall_up_factor != 0:
# Forward upsample size to force interpolation output size.
forward_upsample_size = True
break
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (
1 - encoder_attention_mask.to(sample.dtype)
) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
batch_size, num_frames = sample.shape[:2]
timesteps = timesteps.expand(batch_size)
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb)
time_embeds = self.add_time_proj(added_time_ids.flatten())
time_embeds = time_embeds.reshape((batch_size, -1))
time_embeds = time_embeds.to(emb.dtype)
aug_emb = self.add_embedding(time_embeds)
emb = emb + aug_emb
# Flatten the batch and frames dimensions
# sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
sample = sample.flatten(0, 1)
# Repeat the embeddings num_video_frames times
# emb: [batch, channels] -> [batch * frames, channels]
emb = emb.repeat_interleave(num_frames, dim=0)
# encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
encoder_hidden_states = encoder_hidden_states.repeat_interleave(
num_frames, dim=0
)
# 2. pre-process
sample = self.conv_in(sample)
# 3. down
lora_scale = (
cross_attention_kwargs.get("scale", 1.0)
if cross_attention_kwargs is not None
else 1.0
)
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
image_only_indicator = torch.zeros(
batch_size, num_frames, dtype=sample.dtype, device=sample.device
)
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
scale=lora_scale,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
if (
hasattr(self.mid_block, "has_cross_attention")
and self.mid_block.has_cross_attention
):
sample = self.mid_block(
hidden_states=sample,
temb=emb,
num_video_frames=num_frames,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
image_only_indicator=image_only_indicator,
)
else:
sample = self.mid_block(
sample,
temb=emb,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[
: -len(upsample_block.resnets)
]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
num_video_frames=num_frames,
image_only_indicator=image_only_indicator,
)
# 6. post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
# 7. Reshape back to original shape
sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)
return UNetSpatioTemporalConditionOutput(sample=sample)
+150 -42
View File
@@ -22,7 +22,12 @@ from ..utils import BaseOutput, is_torch_version
from ..utils.torch_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
from .unet_2d_blocks import AutoencoderTinyBlock, UNetMidBlock2D, get_down_block, get_up_block
from .unet_2d_blocks import (
AutoencoderTinyBlock,
UNetMidBlock2D,
get_down_block,
get_up_block,
)
@dataclass
@@ -122,11 +127,15 @@ class Encoder(nn.Module):
)
# out
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
self.conv_out = nn.Conv2d(
block_out_channels[-1], conv_out_channels, 3, padding=1
)
self.gradient_checkpointing = False
@@ -155,9 +164,13 @@ class Encoder(nn.Module):
)
else:
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), sample
)
# middle
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample
)
else:
# down
@@ -267,14 +280,18 @@ class Decoder(nn.Module):
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
self.gradient_checkpointing = False
def forward(
self, sample: torch.FloatTensor, latent_embeds: Optional[torch.FloatTensor] = None
self,
sample: torch.FloatTensor,
latent_embeds: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
@@ -292,14 +309,20 @@ class Decoder(nn.Module):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
else:
# middle
@@ -310,7 +333,9 @@ class Decoder(nn.Module):
# up
for up_block in self.up_blocks:
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds
)
else:
# middle
sample = self.mid_block(sample, latent_embeds)
@@ -350,7 +375,9 @@ class UpSample(nn.Module):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
self.deconv = nn.ConvTranspose2d(
in_channels, out_channels, kernel_size=4, stride=2, padding=1
)
def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `UpSample` class."""
@@ -394,9 +421,13 @@ class MaskConditionEncoder(nn.Module):
for l in range(len(out_channels)):
out_ch_ = out_channels[l]
if l == 0 or l == 1:
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
layers.append(
nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)
)
else:
layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
layers.append(
nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)
)
in_ch_ = out_ch_
self.layers = nn.Sequential(*layers)
@@ -511,7 +542,9 @@ class MaskConditionDecoder(nn.Module):
if norm_type == "spatial":
self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
else:
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
@@ -540,7 +573,10 @@ class MaskConditionDecoder(nn.Module):
if is_torch_version(">=", "1.11.0"):
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(self.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
@@ -548,17 +584,25 @@ class MaskConditionDecoder(nn.Module):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
create_custom_forward(self.condition_encoder),
masked_image,
mask,
use_reentrant=False,
)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
mask_ = nn.functional.interpolate(
mask, size=sample.shape[-2:], mode="nearest"
)
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
@@ -573,16 +617,22 @@ class MaskConditionDecoder(nn.Module):
if image is not None and mask is not None:
masked_image = (1 - mask) * image
im_x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.condition_encoder), masked_image, mask
create_custom_forward(self.condition_encoder),
masked_image,
mask,
)
# up
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
mask_ = nn.functional.interpolate(
mask, size=sample.shape[-2:], mode="nearest"
)
sample = sample * mask_ + sample_ * (1 - mask_)
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), sample, latent_embeds
)
if image is not None and mask is not None:
sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
else:
@@ -599,7 +649,9 @@ class MaskConditionDecoder(nn.Module):
for up_block in self.up_blocks:
if image is not None and mask is not None:
sample_ = im_x[str(tuple(sample.shape))]
mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
mask_ = nn.functional.interpolate(
mask, size=sample.shape[-2:], mode="nearest"
)
sample = sample * mask_ + sample_ * (1 - mask_)
sample = up_block(sample, latent_embeds)
if image is not None and mask is not None:
@@ -671,7 +723,9 @@ class VectorQuantizer(nn.Module):
new = match.argmax(-1)
unknown = match.sum(2) < 1
if self.unknown_index == "random":
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(
device=new.device
)
else:
new[unknown] = self.unknown_index
return new.reshape(ishape)
@@ -686,13 +740,17 @@ class VectorQuantizer(nn.Module):
back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
return back.reshape(ishape)
def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
def forward(
self, z: torch.FloatTensor
) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.vq_embed_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
min_encoding_indices = torch.argmin(
torch.cdist(z_flattened, self.embedding.weight), dim=1
)
z_q = self.embedding(min_encoding_indices).view(z.shape)
perplexity = None
@@ -700,9 +758,13 @@ class VectorQuantizer(nn.Module):
# compute loss for embedding
if not self.legacy:
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean(
(z_q - z.detach()) ** 2
)
else:
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
(z_q - z.detach()) ** 2
)
# preserve gradients
z_q: torch.FloatTensor = z + (z_q - z).detach()
@@ -711,16 +773,22 @@ class VectorQuantizer(nn.Module):
z_q = z_q.permute(0, 3, 1, 2).contiguous()
if self.remap is not None:
min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
min_encoding_indices = min_encoding_indices.reshape(
z.shape[0], -1
) # add batch axis
min_encoding_indices = self.remap_to_used(min_encoding_indices)
min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
if self.sane_index_shape:
min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
min_encoding_indices = min_encoding_indices.reshape(
z_q.shape[0], z_q.shape[2], z_q.shape[3]
)
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
def get_codebook_entry(
self, indices: torch.LongTensor, shape: Tuple[int, ...]
) -> torch.FloatTensor:
# shape specifying (batch, height, width, channel)
if self.remap is not None:
indices = indices.reshape(shape[0], -1) # add batch axis
@@ -754,7 +822,10 @@ class DiagonalGaussianDistribution(object):
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
# make sure sample is on the same device as the parameters and has same dtype
sample = randn_tensor(
self.mean.shape, generator=generator, device=self.parameters.device, dtype=self.parameters.dtype
self.mean.shape,
generator=generator,
device=self.parameters.device,
dtype=self.parameters.dtype,
)
x = self.mean + self.std * sample
return x
@@ -764,7 +835,10 @@ class DiagonalGaussianDistribution(object):
return torch.Tensor([0.0])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3],
)
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
@@ -775,11 +849,16 @@ class DiagonalGaussianDistribution(object):
dim=[1, 2, 3],
)
def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
def nll(
self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]
) -> torch.Tensor:
if self.deterministic:
return torch.Tensor([0.0])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims,
)
def mode(self) -> torch.Tensor:
return self.mean
@@ -818,14 +897,27 @@ class EncoderTiny(nn.Module):
num_channels = block_out_channels[i]
if i == 0:
layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
layers.append(
nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)
)
else:
layers.append(nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1, stride=2, bias=False))
layers.append(
nn.Conv2d(
num_channels,
num_channels,
kernel_size=3,
padding=1,
stride=2,
bias=False,
)
)
for _ in range(num_block):
layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
layers.append(
nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)
)
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
@@ -841,9 +933,13 @@ class EncoderTiny(nn.Module):
return custom_forward
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.layers), x, use_reentrant=False
)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.layers), x
)
else:
# scale image from [-1, 1] to [0, 1] to match TAESD convention
@@ -899,7 +995,15 @@ class DecoderTiny(nn.Module):
layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
conv_out_channel = num_channels if not is_final_block else out_channels
layers.append(nn.Conv2d(num_channels, conv_out_channel, kernel_size=3, padding=1, bias=is_final_block))
layers.append(
nn.Conv2d(
num_channels,
conv_out_channel,
kernel_size=3,
padding=1,
bias=is_final_block,
)
)
self.layers = nn.Sequential(*layers)
self.gradient_checkpointing = False
@@ -918,9 +1022,13 @@ class DecoderTiny(nn.Module):
return custom_forward
if is_torch_version(">=", "1.11.0"):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.layers), x, use_reentrant=False
)
else:
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
x = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.layers), x
)
else:
x = self.layers(x)
+7
View File
@@ -110,6 +110,7 @@ else:
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
_import_structure["kandinsky3"] = ["Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline"]
_import_structure["latent_consistency_models"] = [
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
@@ -144,6 +145,7 @@ else:
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableDiffusionVideoPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
]
@@ -338,6 +340,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
)
from .kandinsky3 import (
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
)
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline
@@ -367,6 +373,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionVideoPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
)
@@ -17,11 +17,11 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from packaging import version
from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -73,8 +73,55 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
class AltDiffusionPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Alt Diffusion.
@@ -86,6 +133,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -108,7 +156,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -121,6 +169,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -197,10 +246,9 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -446,6 +494,19 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -646,6 +707,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -654,6 +716,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -676,6 +739,10 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -700,6 +767,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -799,15 +867,20 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
lora_scale=lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -825,7 +898,10 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.5 Optionally get Guidance Scale Embedding
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 6.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -849,6 +925,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
@@ -19,11 +19,11 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPImageProcessor, XLMRobertaTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -76,9 +76,13 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
@@ -109,9 +113,54 @@ def preprocess(image):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
class AltDiffusionImg2ImgPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image-to-image generation using Alt Diffusion.
@@ -124,6 +173,7 @@ class AltDiffusionImg2ImgPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -146,7 +196,7 @@ class AltDiffusionImg2ImgPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -159,6 +209,7 @@ class AltDiffusionImg2ImgPipeline(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -235,10 +286,9 @@ class AltDiffusionImg2ImgPipeline(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -455,6 +505,19 @@ class AltDiffusionImg2ImgPipeline(
return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -700,6 +763,7 @@ class AltDiffusionImg2ImgPipeline(
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
timesteps: List[int] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -707,6 +771,7 @@ class AltDiffusionImg2ImgPipeline(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -736,6 +801,10 @@ class AltDiffusionImg2ImgPipeline(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -756,6 +825,7 @@ class AltDiffusionImg2ImgPipeline(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -848,11 +918,16 @@ class AltDiffusionImg2ImgPipeline(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Preprocess image
image = self.image_processor.preprocess(image)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
@@ -870,7 +945,10 @@ class AltDiffusionImg2ImgPipeline(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.5 Optionally get Guidance Scale Embedding
# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 7.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -894,6 +972,7 @@ class AltDiffusionImg2ImgPipeline(
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
@@ -18,10 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unet_motion_model import MotionAdapter
@@ -77,7 +77,7 @@ class AnimateDiffPipelineOutput(BaseOutput):
frames: Union[torch.Tensor, np.ndarray]
class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
@@ -101,6 +101,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["feature_extractor", "image_encoder"]
def __init__(
self,
@@ -117,6 +118,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
feature_extractor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModelWithProjection = None,
):
super().__init__()
unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
@@ -128,10 +131,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
unet=unet,
motion_adapter=motion_adapter,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
@@ -316,6 +319,20 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
@@ -514,6 +531,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -560,6 +578,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or
`np.array`.
@@ -631,6 +650,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_videos_per_prompt)
if do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -651,6 +675,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -666,6 +692,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLo
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
@@ -94,9 +94,7 @@ class AudioLDMPipeline(DiffusionPipeline):
scheduler=scheduler,
vocoder=vocoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
@@ -171,9 +171,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
scheduler=scheduler,
vocoder=vocoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
+3
View File
@@ -42,6 +42,7 @@ from .kandinsky2_2 import (
KandinskyV22InpaintPipeline,
KandinskyV22Pipeline,
)
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pixart_alpha import PixArtAlphaPipeline
from .stable_diffusion import (
@@ -64,6 +65,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("if", IFPipeline),
("kandinsky", KandinskyCombinedPipeline),
("kandinsky22", KandinskyV22CombinedPipeline),
("kandinsky3", Kandinsky3Pipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
@@ -79,6 +81,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("if", IFImg2ImgPipeline),
("kandinsky", KandinskyImg2ImgCombinedPipeline),
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
("kandinsky3", Kandinsky3Img2ImgPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
@@ -20,10 +20,10 @@ import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -91,8 +91,53 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class StableDiffusionControlNetPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
@@ -102,6 +147,7 @@ class StableDiffusionControlNetPipeline(
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -128,8 +174,9 @@ class StableDiffusionControlNetPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -141,6 +188,7 @@ class StableDiffusionControlNetPipeline(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -173,10 +221,9 @@ class StableDiffusionControlNetPipeline(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -431,6 +478,20 @@ class StableDiffusionControlNetPipeline(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -487,15 +548,21 @@ class StableDiffusionControlNetPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -762,6 +829,10 @@ class StableDiffusionControlNetPipeline(
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -769,6 +840,14 @@ class StableDiffusionControlNetPipeline(
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -778,6 +857,7 @@ class StableDiffusionControlNetPipeline(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -786,16 +866,18 @@ class StableDiffusionControlNetPipeline(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -818,6 +900,10 @@ class StableDiffusionControlNetPipeline(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -842,6 +928,7 @@ class StableDiffusionControlNetPipeline(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -870,6 +957,15 @@ class StableDiffusionControlNetPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -880,6 +976,23 @@ class StableDiffusionControlNetPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -905,9 +1018,12 @@ class StableDiffusionControlNetPipeline(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -931,7 +1047,7 @@ class StableDiffusionControlNetPipeline(
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
@@ -942,7 +1058,7 @@ class StableDiffusionControlNetPipeline(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
@@ -950,6 +1066,11 @@ class StableDiffusionControlNetPipeline(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
image = self.prepare_image(
@@ -988,8 +1109,8 @@ class StableDiffusionControlNetPipeline(
assert False
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -1015,7 +1136,10 @@ class StableDiffusionControlNetPipeline(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.1 Create tensor stating which controlnets to keep
# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 7.2 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
@@ -1080,20 +1204,31 @@ class StableDiffusionControlNetPipeline(
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=cross_attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -92,9 +92,13 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
@@ -164,6 +168,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -208,9 +213,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -521,15 +524,21 @@ class StableDiffusionControlNetImg2ImgPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -810,6 +819,29 @@ class StableDiffusionControlNetImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -831,14 +863,15 @@ class StableDiffusionControlNetImg2ImgPipeline(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -894,12 +927,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -917,6 +944,15 @@ class StableDiffusionControlNetImg2ImgPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -927,6 +963,23 @@ class StableDiffusionControlNetImg2ImgPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -952,8 +1005,13 @@ class StableDiffusionControlNetImg2ImgPipeline(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -963,10 +1021,6 @@ class StableDiffusionControlNetImg2ImgPipeline(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -980,23 +1034,23 @@ class StableDiffusionControlNetImg2ImgPipeline(
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare image
@@ -1012,7 +1066,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(controlnet, MultiControlNetModel):
@@ -1027,7 +1081,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1041,6 +1095,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
latents = self.prepare_latents(
@@ -1070,11 +1125,11 @@ class StableDiffusionControlNetImg2ImgPipeline(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1101,7 +1156,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
return_dict=False,
)
if guess_mode and do_classifier_free_guidance:
if guess_mode and self.do_classifier_free_guidance:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
@@ -1113,20 +1168,30 @@ class StableDiffusionControlNetImg2ImgPipeline(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -104,9 +104,13 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
@@ -286,6 +290,7 @@ class StableDiffusionControlNetInpaintPipeline(
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -330,9 +335,7 @@ class StableDiffusionControlNetInpaintPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -658,18 +661,24 @@ class StableDiffusionControlNetInpaintPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if height is not None and height % 8 != 0 or width is not None and width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -1001,6 +1010,29 @@ class StableDiffusionControlNetInpaintPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1023,14 +1055,15 @@ class StableDiffusionControlNetInpaintPipeline(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -1103,12 +1136,6 @@ class StableDiffusionControlNetInpaintPipeline(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -1126,6 +1153,15 @@ class StableDiffusionControlNetInpaintPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -1136,6 +1172,23 @@ class StableDiffusionControlNetInpaintPipeline(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1163,8 +1216,13 @@ class StableDiffusionControlNetInpaintPipeline(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1174,10 +1232,6 @@ class StableDiffusionControlNetInpaintPipeline(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1191,23 +1245,23 @@ class StableDiffusionControlNetInpaintPipeline(
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# 4. Prepare image
@@ -1220,7 +1274,7 @@ class StableDiffusionControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(controlnet, MultiControlNetModel):
@@ -1235,7 +1289,7 @@ class StableDiffusionControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1263,6 +1317,7 @@ class StableDiffusionControlNetInpaintPipeline(
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
@@ -1299,7 +1354,7 @@ class StableDiffusionControlNetInpaintPipeline(
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
@@ -1319,11 +1374,11 @@ class StableDiffusionControlNetInpaintPipeline(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1350,7 +1405,7 @@ class StableDiffusionControlNetInpaintPipeline(
return_dict=False,
)
if guess_mode and do_classifier_free_guidance:
if guess_mode and self.do_classifier_free_guidance:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
@@ -1365,14 +1420,14 @@ class StableDiffusionControlNetInpaintPipeline(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
@@ -1381,7 +1436,7 @@ class StableDiffusionControlNetInpaintPipeline(
if num_channels_unet == 4:
init_latents_proper = image_latents
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
@@ -1394,6 +1449,16 @@ class StableDiffusionControlNetInpaintPipeline(
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -34,6 +34,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
is_invisible_watermark_available,
logging,
replace_example_docstring,
@@ -53,6 +54,20 @@ if is_invisible_watermark_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -167,6 +182,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -199,9 +215,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -557,6 +571,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -567,14 +582,20 @@ class StableDiffusionXLControlNetInpaintPipeline(
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
f" {type(num_inference_steps)}."
)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -817,12 +838,12 @@ class StableDiffusionXLControlNetInpaintPipeline(
if isinstance(generator, list):
image_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
if self.vae.config.force_upcast:
self.vae.to(dtype)
@@ -1010,6 +1031,29 @@ class StableDiffusionXLControlNetInpaintPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -1041,8 +1085,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
@@ -1055,6 +1097,9 @@ class StableDiffusionXLControlNetInpaintPipeline(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1149,12 +1194,6 @@ class StableDiffusionXLControlNetInpaintPipeline(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -1184,6 +1223,15 @@ class StableDiffusionXLControlNetInpaintPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -1192,6 +1240,23 @@ class StableDiffusionXLControlNetInpaintPipeline(
[`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple. `tuple. When returning a tuple, the first element is a list with the generated images.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1239,8 +1304,13 @@ class StableDiffusionXLControlNetInpaintPipeline(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1250,17 +1320,13 @@ class StableDiffusionXLControlNetInpaintPipeline(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
(
@@ -1273,7 +1339,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt_2=prompt_2,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1281,7 +1347,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# 4. set timesteps
@@ -1302,6 +1368,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0
self._num_timesteps = len(timesteps)
# 5. Preprocess mask and image - resizes image and mask w.r.t height and width
# 5.1 Prepare init image
@@ -1318,7 +1385,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(controlnet, MultiControlNetModel):
@@ -1333,7 +1400,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1387,7 +1454,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
)
# 8. Check that sizes of mask, masked image and latents match
@@ -1448,7 +1515,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
@@ -1485,7 +1552,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# concat latents, mask, masked_image_latents in the channel dimension
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1493,7 +1560,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1530,7 +1597,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
return_dict=False,
)
if guess_mode and do_classifier_free_guidance:
if guess_mode and self.do_classifier_free_guidance:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
@@ -1545,7 +1612,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs,
@@ -1553,11 +1620,11 @@ class StableDiffusionXLControlNetInpaintPipeline(
)[0]
# perform guidance
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
@@ -1566,7 +1633,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
if num_channels_unet == 4:
init_latents_proper = image_latents
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
init_mask, _ = mask.chunk(2)
else:
init_mask = mask
@@ -1579,6 +1646,16 @@ class StableDiffusionXLControlNetInpaintPipeline(
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -20,12 +20,23 @@ import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from diffusers.utils.import_utils import is_invisible_watermark_available
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import (
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
@@ -35,7 +46,14 @@ from ...models.attention_processor import (
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -97,7 +115,11 @@ EXAMPLE_DOC_STRING = """
class StableDiffusionXLControlNetPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
IPAdapterMixin,
FromSingleFileMixin,
):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -142,7 +164,15 @@ class StableDiffusionXLControlNetPipeline(
# leave controlnet out on purpose because it iterates with unet
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_optional_components = [
"tokenizer",
"tokenizer_2",
"text_encoder",
"text_encoder_2",
"feature_extractor",
"image_encoder",
]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -156,6 +186,8 @@ class StableDiffusionXLControlNetPipeline(
scheduler: KarrasDiffusionSchedulers,
force_zeros_for_empty_prompt: bool = True,
add_watermarker: Optional[bool] = None,
feature_extractor: CLIPImageProcessor = None,
image_encoder: CLIPVisionModelWithProjection = None,
):
super().__init__()
@@ -171,10 +203,10 @@ class StableDiffusionXLControlNetPipeline(
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
out_channels = getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)
self.vae_scale_factor = 2 ** (len(out_channels) - 1)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -456,6 +488,20 @@ class StableDiffusionXLControlNetPipeline(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -489,15 +535,21 @@ class StableDiffusionXLControlNetPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -827,6 +879,10 @@ class StableDiffusionXLControlNetPipeline(
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@@ -834,6 +890,14 @@ class StableDiffusionXLControlNetPipeline(
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -855,10 +919,9 @@ class StableDiffusionXLControlNetPipeline(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
@@ -871,6 +934,9 @@ class StableDiffusionXLControlNetPipeline(
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
negative_target_size: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
@@ -934,17 +1000,12 @@ class StableDiffusionXLControlNetPipeline(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -991,6 +1052,15 @@ class StableDiffusionXLControlNetPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -999,6 +1069,23 @@ class StableDiffusionXLControlNetPipeline(
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned containing the output images.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1028,9 +1115,12 @@ class StableDiffusionXLControlNetPipeline(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -1052,9 +1142,9 @@ class StableDiffusionXLControlNetPipeline(
)
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
# 3.1 Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
(
prompt_embeds,
@@ -1074,9 +1164,15 @@ class StableDiffusionXLControlNetPipeline(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# 3.2 Encode ip_adapter_image
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
image = self.prepare_image(
@@ -1117,6 +1213,7 @@ class StableDiffusionXLControlNetPipeline(
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -1250,13 +1347,16 @@ class StableDiffusionXLControlNetPipeline(
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
if ip_adapter_image is not None:
added_cond_kwargs["image_embeds"] = image_embeds
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=cross_attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs,
@@ -1271,6 +1371,16 @@ class StableDiffusionXLControlNetPipeline(
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -37,6 +37,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
@@ -132,9 +133,13 @@ EXAMPLE_DOC_STRING = """
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
@@ -195,6 +200,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
_optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
@@ -225,9 +231,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
controlnet=controlnet,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.control_image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -545,6 +549,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
@@ -555,14 +560,20 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
f" {type(num_inference_steps)}."
)
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -953,6 +964,29 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -978,8 +1012,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
guess_mode: bool = False,
@@ -994,6 +1026,9 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
aesthetic_score: float = 6.0,
negative_aesthetic_score: float = 2.5,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -1079,12 +1114,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -1140,6 +1169,15 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
@@ -1148,6 +1186,23 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`
containing the output images.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
@@ -1179,8 +1234,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -1190,10 +1250,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1207,7 +1263,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
(
prompt_embeds,
@@ -1219,7 +1275,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
prompt_2,
device,
num_images_per_prompt,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
negative_prompt,
negative_prompt_2,
prompt_embeds=prompt_embeds,
@@ -1227,7 +1283,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# 4. Prepare image and controlnet_conditioning_image
@@ -1242,7 +1298,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
@@ -1258,7 +1314,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
@@ -1273,6 +1329,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
latents = self.prepare_latents(
@@ -1330,7 +1387,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
)
add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
@@ -1345,13 +1402,13 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
if guess_mode and self.do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1384,7 +1441,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
return_dict=False,
)
if guess_mode and do_classifier_free_guidance:
if guess_mode and self.do_classifier_free_guidance:
# Infered ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
@@ -1396,7 +1453,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
added_cond_kwargs=added_cond_kwargs,
@@ -1404,13 +1461,23 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
)[0]
# perform guidance
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
@@ -175,9 +175,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
def prepare_text_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
@@ -0,0 +1,49 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
_import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .kandinsky3_pipeline import Kandinsky3Pipeline
from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,452 @@
from typing import Callable, List, Optional, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...loaders import LoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->unet->movq"
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
unet: Kandinsky3UNet,
scheduler: DDPMScheduler,
movq: VQModel,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
)
def remove_all_hooks(self):
if is_accelerate_available():
from accelerate.hooks import remove_hook_from_module
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
for model in [self.text_encoder, self.unet]:
if model is not None:
remove_hook_from_module(model, recurse=True)
self.unet_offload_hook = None
self.text_encoder_offload_hook = None
self.final_offload_hook = None
def process_embeds(self, embeddings, attention_mask, cut_context):
if cut_context:
embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])
max_seq_length = attention_mask.sum(-1).max() + 1
embeddings = embeddings[:, :max_seq_length]
attention_mask = attention_mask[:, :max_seq_length]
return embeddings, attention_mask
@torch.no_grad()
def encode_prompt(
self,
prompt,
do_classifier_free_guidance=True,
num_images_per_prompt=1,
device=None,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
_cut_context=False,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
max_length = 128
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds, attention_mask = self.process_embeds(prompt_embeds, attention_mask, _cut_context)
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2)
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
if negative_prompt is not None:
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=128,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = uncond_input.input_ids.to(device)
negative_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=negative_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]]
negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]]
negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_attention_mask = torch.zeros_like(attention_mask)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
if negative_prompt_embeds.shape != prompt_embeds.shape:
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
else:
negative_prompt_embeds = None
negative_attention_mask = None
return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def check_inputs(
self,
prompt,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_inference_steps: int = 25,
guidance_scale: float = 3.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = 1024,
width: Optional[int] = 1024,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
latents=None,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 3.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
"""
cut_context = True
device = self._execution_device
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
prompt,
do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
device=device,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
_cut_context=cut_context,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents
height, width = downscale_height_and_width(height, width, 8)
latents = self.prepare_latents(
(batch_size * num_images_per_prompt, 4, height, width),
prompt_embeds.dtype,
device,
generator,
latents,
self.scheduler,
)
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
self.text_encoder_offload_hook.offload()
# 7. Denoising loop
# TODO(Yiyi): Correct the following line and use correctly
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=attention_mask,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
# noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
generator=generator,
).prev_sample
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
)
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -0,0 +1,460 @@
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import PIL.Image
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...loaders import LoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
def prepare_image(pil_image):
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image
class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->unet->movq"
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
unet: Kandinsky3UNet,
scheduler: DDPMScheduler,
movq: VQModel,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
def remove_all_hooks(self):
if is_accelerate_available():
from accelerate.hooks import remove_hook_from_module
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
for model in [self.text_encoder, self.unet]:
if model is not None:
remove_hook_from_module(model, recurse=True)
self.unet_offload_hook = None
self.text_encoder_offload_hook = None
self.final_offload_hook = None
def _process_embeds(self, embeddings, attention_mask, cut_context):
# return embeddings, attention_mask
if cut_context:
embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])
max_seq_length = attention_mask.sum(-1).max() + 1
embeddings = embeddings[:, :max_seq_length]
attention_mask = attention_mask[:, :max_seq_length]
return embeddings, attention_mask
@torch.no_grad()
def encode_prompt(
self,
prompt,
do_classifier_free_guidance=True,
num_images_per_prompt=1,
device=None,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
_cut_context=False,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
max_length = 128
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds, attention_mask = self._process_embeds(prompt_embeds, attention_mask, _cut_context)
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2)
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
if negative_prompt is not None:
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=128,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = uncond_input.input_ids.to(device)
negative_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=negative_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]]
negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]]
negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_attention_mask = torch.zeros_like(attention_mask)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
if negative_prompt_embeds.shape != prompt_embeds.shape:
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
else:
negative_prompt_embeds = None
negative_attention_mask = None
return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.movq.encode(image).latent_dist.sample(generator)
init_latents = self.movq.config.scaling_factor * init_latents
init_latents = torch.cat([init_latents], dim=0)
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
strength: float = 0.3,
num_inference_steps: int = 25,
guidance_scale: float = 3.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
latents=None,
):
cut_context = True
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
prompt,
do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
device=device,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
_cut_context=cut_context,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
if not isinstance(image, list):
image = [image]
if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
image = torch.cat([prepare_image(i) for i in image], dim=0)
image = image.to(dtype=prompt_embeds.dtype, device=device)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
# 5. Prepare latents
latents = self.movq.encode(image)["latents"]
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
self.text_encoder_offload_hook.offload()
# 7. Denoising loop
# TODO(Yiyi): Correct the following line and use correctly
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=attention_mask,
)[0]
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
generator=generator,
).prev_sample
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
)
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
@@ -44,15 +44,64 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
EXAMPLE_DOC_STRING = """
Examples:
```py
@@ -154,9 +203,7 @@ class LatentConsistencyModelImg2ImgPipeline(
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
@@ -594,6 +641,7 @@ class LatentConsistencyModelImg2ImgPipeline(
num_inference_steps: int = 4,
strength: float = 0.8,
original_inference_steps: int = None,
timesteps: List[int] = None,
guidance_scale: float = 8.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -625,6 +673,10 @@ class LatentConsistencyModelImg2ImgPipeline(
we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule,
following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the
scheduler's `original_inference_steps` attribute.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending
order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -730,10 +782,14 @@ class LatentConsistencyModelImg2ImgPipeline(
image = self.image_processor.preprocess(image)
# 5. Prepare timesteps
self.scheduler.set_timesteps(
num_inference_steps, device, original_inference_steps=original_inference_steps, strength=strength
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
original_inference_steps=original_inference_steps,
strength=strength,
)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
original_inference_steps = (
@@ -61,6 +61,51 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class LatentConsistencyModelPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
@@ -141,9 +186,7 @@ class LatentConsistencyModelPipeline(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -532,6 +575,7 @@ class LatentConsistencyModelPipeline(
width: Optional[int] = None,
num_inference_steps: int = 4,
original_inference_steps: int = None,
timesteps: List[int] = None,
guidance_scale: float = 8.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -563,6 +607,10 @@ class LatentConsistencyModelPipeline(
we will draw `num_inference_steps` evenly spaced timesteps from as our final timestep schedule,
following the Skipping-Step method in the paper (see Section 4.3). If not set this will default to the
scheduler's `original_inference_steps` attribute.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps on the original LCM training/distillation timestep schedule are used. Must be in descending
order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -670,8 +718,9 @@ class LatentConsistencyModelPipeline(
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device, original_inference_steps=original_inference_steps)
timesteps = self.scheduler.timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, original_inference_steps=original_inference_steps
)
# 5. Prepare latent variable
num_channels_latents = self.unet.config.in_channels
@@ -51,7 +51,7 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> import scipy
>>> repo_id = "cvssp/audioldm-s-full-v2"
>>> repo_id = "ucsd-reach/musicldm"
>>> pipe = MusicLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
@@ -111,9 +111,7 @@ class MusicLDMPipeline(DiffusionPipeline):
scheduler=scheduler,
vocoder=vocoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
def enable_vae_slicing(self):
@@ -35,9 +35,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
@@ -205,9 +209,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -538,12 +538,13 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
model = pipeline_class(**init_kwargs, dtype=dtype)
return model, params
@staticmethod
def _get_signature_keys(obj):
@classmethod
def _get_signature_keys(cls, obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
return expected_modules, optional_parameters
@property
+11 -4
View File
@@ -259,7 +259,7 @@ def warn_deprecated_model_variant(pretrained_model_name_or_path, use_auth_token,
comp_model_filenames, _ = variant_compatible_siblings(filenames, variant=revision)
comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]
if set(comp_model_filenames) == set(model_filenames):
if set(model_filenames).issubset(set(comp_model_filenames)):
warnings.warn(
f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead.",
FutureWarning,
@@ -557,7 +557,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
for name, module in kwargs.items():
# retrieve library
if module is None:
if module is None or isinstance(module, (tuple, list)) and module[0] is None:
register_dict = {name: (None, None)}
else:
# register the config from the original module, not the dynamo compiled one
@@ -1906,12 +1906,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
" above."
) from model_info_call_error
@staticmethod
def _get_signature_keys(obj):
@classmethod
def _get_signature_keys(cls, obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - {"self"}
optional_names = list(optional_parameters)
for name in optional_names:
if name in cls._optional_components:
expected_modules.add(name)
optional_parameters.remove(name)
return expected_modules, optional_parameters
@property
@@ -97,6 +97,42 @@ ASPECT_RATIO_1024_BIN = {
"4.0": [2048.0, 512.0],
}
ASPECT_RATIO_512_BIN = {
"0.25": [256.0, 1024.0],
"0.28": [256.0, 928.0],
"0.32": [288.0, 896.0],
"0.33": [288.0, 864.0],
"0.35": [288.0, 832.0],
"0.4": [320.0, 800.0],
"0.42": [320.0, 768.0],
"0.48": [352.0, 736.0],
"0.5": [352.0, 704.0],
"0.52": [352.0, 672.0],
"0.57": [384.0, 672.0],
"0.6": [384.0, 640.0],
"0.68": [416.0, 608.0],
"0.72": [416.0, 576.0],
"0.78": [448.0, 576.0],
"0.82": [448.0, 544.0],
"0.88": [480.0, 544.0],
"0.94": [480.0, 512.0],
"1.0": [512.0, 512.0],
"1.07": [512.0, 480.0],
"1.13": [544.0, 480.0],
"1.21": [544.0, 448.0],
"1.29": [576.0, 448.0],
"1.38": [576.0, 416.0],
"1.46": [608.0, 416.0],
"1.67": [640.0, 384.0],
"1.75": [672.0, 384.0],
"2.0": [704.0, 352.0],
"2.09": [736.0, 352.0],
"2.4": [768.0, 320.0],
"2.5": [800.0, 320.0],
"3.0": [864.0, 288.0],
"4.0": [1024.0, 256.0],
}
class PixArtAlphaPipeline(DiffusionPipeline):
r"""
@@ -154,9 +190,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py
@@ -693,8 +727,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
aspect_ratio_bin = (
ASPECT_RATIO_1024_BIN if self.transformer.config.sample_size == 128 else ASPECT_RATIO_512_BIN
)
orig_height, orig_width = height, width
height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN)
height, width = self.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt,
@@ -87,9 +87,7 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -47,6 +47,7 @@ else:
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
_import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
_import_structure["pipeline_stable_diffusion_video"] = ["StableDiffusionVideoPipeline"]
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
_import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
_import_structure["safety_checker"] = ["StableDiffusionSafetyChecker"]
@@ -151,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_stable_diffusion_paradigms import StableDiffusionParadigmsPipeline
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
from .pipeline_stable_diffusion_video import StableDiffusionVideoPipeline
from .pipeline_stable_unclip import StableUnCLIPPipeline
from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
from .safety_checker import StableDiffusionSafetyChecker
@@ -61,6 +61,20 @@ def preprocess(image):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def posterior_sample(scheduler, latents, timestep, clean_latents, generator, eta):
# 1. get previous step value (=t-1)
prev_timestep = timestep - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
@@ -224,9 +238,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -569,11 +581,12 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
if isinstance(generator, list):
init_latents = [
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.vae.encode(image).latent_dist.sample(generator)
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
init_latents = self.vae.config.scaling_factor * init_latents
@@ -17,11 +17,11 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -70,7 +70,53 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
return noise_cfg
class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class StableDiffusionPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion.
@@ -82,6 +128,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -104,7 +151,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -117,6 +164,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -193,10 +241,9 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -442,6 +489,19 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
return prompt_embeds, negative_prompt_embeds
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
@@ -643,6 +703,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -651,6 +712,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -673,6 +735,10 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -697,6 +763,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -796,15 +863,20 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
lora_scale=lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
@@ -822,7 +894,10 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.5 Optionally get Guidance Scale Embedding
# 6.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 6.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -846,6 +921,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
@@ -239,9 +239,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -37,9 +37,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
@@ -141,9 +145,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
depth_estimator=depth_estimator,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -367,9 +367,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
feature_extractor=feature_extractor,
inverse_scheduler=inverse_scheduler,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -168,9 +168,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -226,9 +226,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -126,9 +126,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -19,11 +19,11 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -73,9 +73,13 @@ EXAMPLE_DOC_STRING = """
"""
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
@@ -105,8 +109,53 @@ def preprocess(image):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class StableDiffusionImg2ImgPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image-to-image generation using Stable Diffusion.
@@ -119,6 +168,7 @@ class StableDiffusionImg2ImgPipeline(
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -141,7 +191,7 @@ class StableDiffusionImg2ImgPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
@@ -154,6 +204,7 @@ class StableDiffusionImg2ImgPipeline(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -230,10 +281,9 @@ class StableDiffusionImg2ImgPipeline(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -452,6 +502,20 @@ class StableDiffusionImg2ImgPipeline(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -703,6 +767,7 @@ class StableDiffusionImg2ImgPipeline(
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
timesteps: List[int] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -710,6 +775,7 @@ class StableDiffusionImg2ImgPipeline(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -739,6 +805,10 @@ class StableDiffusionImg2ImgPipeline(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -759,6 +829,7 @@ class StableDiffusionImg2ImgPipeline(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -851,11 +922,16 @@ class StableDiffusionImg2ImgPipeline(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. Preprocess image
image = self.image_processor.preprocess(image)
# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
@@ -873,7 +949,10 @@ class StableDiffusionImg2ImgPipeline(
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.5 Optionally get Guidance Scale Embedding
# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 7.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -897,6 +976,7 @@ class StableDiffusionImg2ImgPipeline(
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
@@ -19,11 +19,11 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -160,17 +160,66 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(encoder_output, generator):
if hasattr(encoder_output, "latent_dist"):
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class StableDiffusionInpaintPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion.
@@ -182,6 +231,7 @@ class StableDiffusionInpaintPipeline(
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
@@ -204,7 +254,7 @@ class StableDiffusionInpaintPipeline(
"""
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "mask", "masked_image_latents"]
@@ -217,6 +267,7 @@ class StableDiffusionInpaintPipeline(
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -298,10 +349,9 @@ class StableDiffusionInpaintPipeline(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -523,6 +573,20 @@ class StableDiffusionInpaintPipeline(
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
@@ -831,6 +895,7 @@ class StableDiffusionInpaintPipeline(
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
@@ -839,6 +904,7 @@ class StableDiffusionInpaintPipeline(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -880,6 +946,10 @@ class StableDiffusionInpaintPipeline(
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
@@ -904,6 +974,7 @@ class StableDiffusionInpaintPipeline(
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1031,8 +1102,13 @@ class StableDiffusionInpaintPipeline(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None:
image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
if self.do_classifier_free_guidance:
image_embeds = torch.cat([negative_image_embeds, image_embeds])
# 4. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
@@ -1119,7 +1195,10 @@ class StableDiffusionInpaintPipeline(
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 9.5 Optionally get Guidance Scale Embedding
# 9.1 Add image embeds for IP-Adapter
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
# 9.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
@@ -1148,6 +1227,7 @@ class StableDiffusionInpaintPipeline(
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
@@ -213,9 +213,7 @@ class StableDiffusionInpaintPipelineLegacy(
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -58,6 +58,20 @@ def preprocess(image):
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for pixel-level image editing by following text instructions (based on Stable Diffusion).
@@ -133,9 +147,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -322,7 +334,6 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
prompt_embeds.dtype,
device,
self.do_classifier_free_guidance,
generator,
)
height, width = image_latents.shape[-2:]
@@ -718,17 +729,7 @@ class StableDiffusionInstructPix2PixPipeline(DiffusionPipeline, TextualInversion
if image.shape[1] == 4:
image_latents = image
else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if isinstance(generator, list):
image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = self.vae.encode(image).latent_dist.mode()
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand image_latents for batch_size
@@ -117,9 +117,7 @@ class StableDiffusionKDiffusionPipeline(DiffusionPipeline, TextualInversionLoade
feature_extractor=feature_extractor,
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.vae_scale_factor = 2 ** (
len(getattr(self.vae.config, "block_out_channels", self.vae.config.decoder_block_out_channels)) - 1
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
model = ModelWrapper(unet, scheduler.alphas_cumprod)

Some files were not shown because too many files have changed in this diff Show More