Compare commits

..

112 Commits

Author SHA1 Message Date
YiYi Xu 6a509ba862 Merge branch 'main' into modular-diffusers 2025-04-30 17:56:25 -10:00
Yao Matrix 06beecafc5 make autoencoders. controlnet_flux and wan_transformer3d_single_file pass on xpu (#11461)
* make autoencoders. controlnet_flux and wan_transformer3d_single_file
pass on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* Apply style fixes

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2025-05-01 02:43:31 +05:30
Vaibhav Kumawat daf0a23958 Add LANCZOS as default interplotation mode. (#11463)
* Add LANCZOS as default interplotation mode.

* LANCZOS as default interplotation

* LANCZOS as default interplotation mode

* Added LANCZOS as default interplotation mode
2025-04-30 14:22:38 -04:00
tongyu 38ced7ee59 [test_models_transformer_hunyuan_video] help us test torch.compile() for impactful models (#11431)
* Update test_models_transformer_hunyuan_video.py

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-30 19:11:42 +08:00
Yao Matrix 23c98025b3 make safe diffusion test cases pass on XPU and A100 (#11458)
* make safe diffusion test cases pass on XPU and A100

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* calibrate A100 expected values

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-04-30 16:05:28 +05:30
captainzz 8cd7426e56 Add StableDiffusion3InstructPix2PixPipeline (#11378)
* upload StableDiffusion3InstructPix2PixPipeline

* Move to community

* Add readme

* Fix images

* remove images

* Change image url

* fix

* Apply style fixes
2025-04-30 06:13:12 -04:00
Daniel Socek fbce7aeb32 Add generic support for Intel Gaudi accelerator (hpu device) (#11328)
* Add generic support for Intel Gaudi accelerator (hpu device)

Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Co-authored-by: Libin Tang <libin.tang@intel.com>

* Add loggers for generic HPU support

Signed-off-by: Daniel Socek <daniel.socek@intel.com>

* Refactor hpu support with is_hpu_available() logic

Signed-off-by: Daniel Socek <daniel.socek@intel.com>

* Fix style for hpu support update

Signed-off-by: Daniel Socek <daniel.socek@intel.com>

* Decouple soft HPU check from hard device validation to support HPU migration

Signed-off-by: Daniel Socek <daniel.socek@intel.com>

---------

Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Co-authored-by: Libin Tang <libin.tang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-30 14:45:02 +05:30
Yao Matrix 35fada4169 enable unidiffuser test cases on xpu (#11444)
* enable unidiffuser cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix a typo

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
2025-04-30 13:58:00 +05:30
Yao Matrix fbe2fe5578 enable consistency test cases on XPU, all passed (#11446)
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
2025-04-30 12:41:29 +05:30
Aryan c86511586f torch.compile fullgraph compatibility for Hunyuan Video (#11457)
udpate
2025-04-30 11:21:17 +05:30
Yao Matrix 60892c55a4 enable marigold_intrinsics cases on XPU (#11445)
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
2025-04-30 11:07:37 +05:30
Aryan 8fe5a14d9b Raise warning instead of error for block offloading with streams (#11425)
raise warning instead of error
2025-04-30 08:26:16 +05:30
Youlun Peng 58431f102c Set LANCZOS as the default interpolation for image resizing in ControlNet training (#11449)
Set LANCZOS as the default interpolation for image resizing
2025-04-29 08:47:02 -04:00
urpetkov-amd 4a9ab650aa Fixing missing provider options argument (#11397)
* Fixing missing provider options argument

* Adding if else for provider options

* Apply suggestions from code review

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Apply style fixes

* Update src/diffusers/pipelines/onnx_utils.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/onnx_utils.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: Uros Petkovic <urpektov@amd.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-28 10:23:05 -10:00
Linoy Tsaban 0ac1d5b482 [Hi-Dream LoRA] fix bug in validation (#11439)
remove unnecessary pipeline moving to cpu in validation

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-28 06:22:32 -10:00
Yao Matrix 7567adfc45 enable 28 GGUF test cases on XPU (#11404)
* enable gguf test cases on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* make SD35LargeGGUFSingleFileTests::test_pipeline_inference pas

Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>

* make FluxControlLoRAGGUFTests::test_lora_loading pass

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* polish code

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* Apply style fixes

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: root <root@a4bf01945cfe.jf.intel.com>
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Co-authored-by: root <root@a4bf01945cfe.jf.intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-28 21:32:04 +05:30
tongyu 3da98e7ee3 [train_text_to_image_lora] Better image interpolation in training scripts follow up (#11427)
* Update train_text_to_image_lora.py

* update_train_text_to_image_lora
2025-04-28 11:23:24 -04:00
tongyu b3b04fefde [train_text_to_image] Better image interpolation in training scripts follow up (#11426)
* Update train_text_to_image.py

* update
2025-04-28 10:50:33 -04:00
Sayak Paul 0e3f2713c2 [tests] fix import. (#11434)
fix import.
2025-04-28 13:32:28 +08:00
Yao Matrix a7e9f85e21 enable test_layerwise_casting_memory cases on XPU (#11406)
* enable test_layerwise_casting_memory cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
2025-04-28 06:38:39 +05:30
Yao Matrix 9ce89e2efa enable group_offload cases and quanto cases on XPU (#11405)
* enable group_offload cases and quanto cases on XPU

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* use backend APIs

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
2025-04-28 06:37:16 +05:30
Sayak Paul aa5f5d41d6 [tests] add tests to check for graph breaks, recompilation, cuda syncs in pipelines during torch.compile() (#11085)
* test for better torch.compile stuff.

* fixes

* recompilation and graph break.

* clear compilation cache.

* change to modeling level test.

* allow running compilation tests during nightlies.
2025-04-28 08:36:33 +08:00
Mert Erbak bd96a084d3 [train_dreambooth_lora.py] Set LANCZOS as default interpolation mode for resizing (#11421)
* Set LANCZOS as default interpolation mode for resizing

* [train_dreambooth_lora.py] Set LANCZOS as default interpolation mode for resizing
2025-04-26 01:58:41 -04:00
co63oc f00a995753 Fix typos in strings and comments (#11407) 2025-04-24 08:53:47 -10:00
Ishan Modi e8312e7ca9 [BUG] fixed WAN docstring (#11226)
update
2025-04-24 08:49:37 -10:00
Emiliano 7986834572 Fix Flux IP adapter argument in the pipeline example (#11402)
Fix Flux IP adapter argument in the example

IP-Adapter example had a wrong argument. Fix `true_cfg` -> `true_cfg_scale`
2025-04-24 08:41:12 -10:00
Linoy Tsaban edd7880418 [HiDream LoRA] optimizations + small updates (#11381)
* 1. add pre-computation of prompt embeddings when custom prompts are used as well
2. save model card even if model is not pushed to hub
3. remove scheduler initialization from code example - not necessary anymore (it's now if the base model's config)
4. add skip_final_inference - to allow to run with validation, but skip the final loading of the pipeline with the lora weights to reduce memory reqs

* pre encode validation prompt as well

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* pre encode validation prompt as well

* Apply style fixes

* empty commit

* change default trained modules

* empty commit

* address comments + change encoding of validation prompt (before it was only pre-encoded if custom prompts are provided, but should be pre-encoded either way)

* Apply style fixes

* empty commit

* fix validation_embeddings definition

* fix final inference condition

* fix pipeline deletion in last inference

* Apply style fixes

* empty commit

* layers

* remove readme remarks on only pre-computing when instance prompt is provided and change example to 3d icons

* smol fix

* empty commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-24 07:48:19 +03:00
Teriks b4be42282d Kolors additional pipelines, community contrib (#11372)
* Kolors additional pipelines, community contrib

---------

Co-authored-by: Teriks <Teriks@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-04-23 11:07:27 -10:00
Ishan Modi a4f9c3cbc3 [Feature] Added Xlab Controlnet support (#11249)
update
2025-04-23 10:43:50 -10:00
Ishan Dutta 4b60f4b602 [train_dreambooth_flux] Add LANCZOS as the default interpolation mode for image resizing (#11395) 2025-04-23 10:47:05 -04:00
Aryan 6cef71de3a Fix group offloading with block_level and use_stream=True (#11375)
* fix

* add tests

* add message check
2025-04-23 18:17:53 +05:30
Ameer Azam 026507c06c Update README_hidream.md (#11386)
Small change
requirements_sana.txt to 
requirements_hidream.txt
2025-04-22 20:08:26 -04:00
YiYi Xu 448c72a230 [HiDream] move deprecation to 0.35.0 (#11384)
up
2025-04-22 08:08:08 -10:00
Aryan f108ad8888 Update modeling imports (#11129)
update
2025-04-22 06:59:25 -10:00
Linoy Tsaban e30d3bf544 [LoRA] add LoRA support to HiDream and fine-tuning script (#11281)
* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* initial commit

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>

* move prompt embeds, pooled embeds outside

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: hlky <hlky@hlky.ac>

* fix import

* fix import and tokenizer 4, text encoder 4 loading

* te

* prompt embeds

* fix naming

* shapes

* initial commit to add HiDreamImageLoraLoaderMixin

* fix init

* add tests

* loader

* fix model input

* add code example to readme

* fix default max length of text encoders

* prints

* nullify training cond in unpatchify for temp fix to incompatible shaping of transformer output during training

* smol fix

* unpatchify

* unpatchify

* fix validation

* flip pred and loss

* fix shift!!!

* revert unpatchify changes (for now)

* smol fix

* Apply style fixes

* workaround moe training

* workaround moe training

* remove prints

* to reduce some memory, keep vae in `weight_dtype` same as we have for flux (as it's the same vae)
https://github.com/huggingface/diffusers/blob/bbd0c161b55ba2234304f1e6325832dd69c60565/examples/dreambooth/train_dreambooth_lora_flux.py#L1207

* refactor to align with HiDream refactor

* refactor to align with HiDream refactor

* refactor to align with HiDream refactor

* add support for cpu offloading of text encoders

* Apply style fixes

* adjust lr and rank for train example

* fix copies

* Apply style fixes

* update README

* update README

* update README

* fix license

* keep prompt2,3,4 as None in validation

* remove reverse ode comment

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_hidream.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* vae offload change

* fix text encoder offloading

* Apply style fixes

* cleaner to_kwargs

* fix module name in copied from

* add requirements

* fix offloading

* fix offloading

* fix offloading

* update transformers version in reqs

* try AutoTokenizer

* try AutoTokenizer

* Apply style fixes

* empty commit

* Delete tests/lora/test_lora_layers_hidream.py

* change tokenizer_4 to load with AutoTokenizer as well

* make text_encoder_four and tokenizer_four configurable

* save model card

* save model card

* revert T5

* fix test

* remove non diffusers lumina2 conversion

---------

Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-22 11:44:02 +03:00
apolinário 6ab62c7431 Add stochastic sampling to FlowMatchEulerDiscreteScheduler (#11369)
* Add stochastic sampling to FlowMatchEulerDiscreteScheduler

This PR adds stochastic sampling to FlowMatchEulerDiscreteScheduler based on https://github.com/Lightricks/LTX-Video/commit/b1aeddd7ccac85e6d1b0d97762610ddb53c1b408  ltx_video/schedulers/rf.py

* Apply style fixes

* Use config value directly

* Apply style fixes

* Swap order

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-21 17:18:30 -10:00
Ishan Modi f59df3bb8b [Refactor] Minor Improvement for import utils (#11161)
* update

* update

* addressed PR comments

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-21 09:56:55 -10:00
josephrocca a00c73a5e1 Support different-length pos/neg prompts for FLUX.1-schnell variants like Chroma (#11120)
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-21 09:28:19 -10:00
OleehyO 0434db9a99 [cogview4][feat] Support attention mechanism with variable-length support and batch packing (#11349)
* [cogview4] Enhance attention mechanism with variable-length support and batch packing

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-21 09:27:55 -10:00
Aamir Nazir aff574fb29 Add Serialized Type Name kwarg in Model Output (#10502)
* Update outputs.py
2025-04-21 08:45:28 -10:00
Ishan Modi 79ea8eb258 [BUG] fixes in kadinsky pipeline (#11080)
* bug fix kadinsky pipeline
2025-04-21 08:41:09 -10:00
Aryan e7f3a73786 Fix Wan I2V prepare_latents dtype (#11371)
update
2025-04-21 08:18:50 -10:00
PromeAI 7a4a126db8 fix issue that training flux controlnet was unstable and validation r… (#11373)
* fix issue that training flux controlnet was unstable and validation results were unstable

* del unused code pieces, fix grammar

---------

Co-authored-by: Your Name <you@example.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-21 08:16:05 -10:00
YiYi Xu 96795afc72 Merge branch 'main' into modular-diffusers 2025-04-07 18:05:00 -10:00
yiyixuxu 12650e1393 up 2025-02-04 02:08:28 +01:00
yiyixuxu addaad013c more more more refactor 2025-02-03 20:36:05 +01:00
yiyixuxu 485f8d1758 more refactor 2025-02-01 21:30:05 +01:00
yiyixuxu cff0fd6260 more refactor 2025-02-01 11:36:13 +01:00
yiyixuxu 8ddb20bfb8 up 2025-02-01 05:45:00 +01:00
yiyixuxu e5089d702b update 2025-01-31 21:55:45 +01:00
yiyixuxu 2c3e4eafa8 fix 2025-01-29 17:58:40 +01:00
yiyixuxu c7020df2cf add model_info 2025-01-27 11:33:27 +01:00
yiyixuxu 4bed3e306e up up 2025-01-26 13:04:33 +01:00
yiyixuxu 00a3bc9d6c fix 2025-01-23 18:16:00 +01:00
YiYi Xu ccb35acd81 Merge branch 'main' into modular-diffusers 2025-01-23 07:07:11 -10:00
yiyixuxu 00cae4e857 docstring doc doc doc 2025-01-23 11:07:13 +01:00
yiyixuxu b3fb4188f5 Merge branch 'modular-diffusers' of github.com:huggingface/diffusers into modular-diffusers 2025-01-22 17:24:06 +01:00
YiYi Xu 71df1581f7 Update src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py
Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-01-22 06:19:22 -10:00
yiyixuxu d046cf7d35 block state + fix for num_images_per_prompt > 1 for denoise/controlnet union etc 2025-01-22 09:48:57 +01:00
yiyixuxu 68a5185c86 refactor more, ipadapter node, lora node 2025-01-20 03:36:01 +01:00
yiyixuxu 6e2fe26bfd fix more for lora 2025-01-18 08:04:12 +01:00
yiyixuxu 77b5fa59c5 make it work with lora has both text_encoder & unet 2025-01-18 04:12:07 +01:00
yiyixuxu a226920b52 get_block_state make it less verbose 2025-01-17 01:37:18 +01:00
yiyixuxu 7007f72409 InputParam, OutputParam, get_auto_doc 2025-01-16 11:44:24 +01:00
yiyixuxu a6804de4a2 add controlnet union to auto & fix for pag 2025-01-12 16:24:01 +01:00
yiyixuxu 7f897a9fc4 fix 2025-01-12 04:50:45 +01:00
yiyixuxu 0966663d2a adjust print 2025-01-11 19:15:54 +01:00
yiyixuxu fb78f4f12d Merge branch 'modular-diffusers' of github.com:huggingface/diffusers into modular-diffusers 2025-01-11 09:05:56 +01:00
yiyixuxu 2220af6940 refactor 2025-01-11 09:05:47 +01:00
hlky 7a34832d52 [modular] Stable Diffusion XL ControlNet Union (#10509)
StableDiffusionXLControlNetUnionDenoiseStep
2025-01-09 10:29:45 -10:00
yiyixuxu e973de64f9 fix contro;net inpaint preprocess 2025-01-08 21:47:20 +01:00
yiyixuxu db94ca882d add controlnet inpaint + more refactor 2025-01-07 20:49:58 +01:00
yiyixuxu 6985906a2e controlnet input & remove the MultiPipelineBlocks class 2025-01-07 01:56:33 +01:00
yiyixuxu 54f410db6c add inpaint 2025-01-06 09:19:59 +01:00
yiyixuxu c12a05b9c1 update to to not assume pipeline has hf_device_map 2025-01-03 20:57:44 +01:00
yiyixuxu 2e0f5c86cc start to add inpaint 2025-01-03 18:20:39 +01:00
yiyixuxu 1d63306295 make it work with lora 2025-01-03 06:07:25 +01:00
yiyixuxu 6c93626f6f remove run_blocks, just use __call__ 2025-01-02 00:59:12 +01:00
yiyixuxu 72c5bf07c8 add a from_block class method to modular pipeline 2025-01-02 00:49:34 +01:00
yiyixuxu ed59f90f15 modular pipeline builder -> ModularPipeline 2025-01-01 22:15:48 +01:00
yiyixuxu a09ca7f27e refactors: block __init__ no longer accept args. remove update_states from pipeline blocks, add update_states to modularpipeline, remove multi-block support for modular pipeline, remove offload support on modular pipeline 2025-01-01 21:43:20 +01:00
yiyixuxu 8c02572e16 add memory_reserve_margin arg to auto offload 2024-12-31 20:08:53 +01:00
yiyixuxu 27dde51de8 add output arg to run_blocks 2024-12-31 18:06:44 +01:00
yiyixuxu 10d4a775f1 style 2024-12-31 09:55:50 +01:00
yiyixuxu 72d9a81d99 components manager 2024-12-31 09:54:46 +01:00
yiyixuxu 4fa85c7963 add model_manager and global offloading method 2024-12-31 02:57:42 +01:00
YiYi Xu 806e8e66fb Merge branch 'main' into modular-diffusers 2024-12-29 00:44:43 -10:00
yiyixuxu 0b90051db8 add vae encoder node 2024-12-19 17:57:12 +01:00
yiyixuxu b305c779b2 add offload support! 2024-12-14 21:37:21 +01:00
yiyixuxu 2b3cd2d39c update 2024-12-14 03:02:31 +01:00
yiyixuxu bc3d1c9ee6 add model_cpu_offload_seq + _exlude_from_cpu_offload 2024-12-14 00:24:15 +01:00
yiyixuxu e50d614636 only add model as expected_component when the model need to run for the block, currently it's added even when only config is needed 2024-12-11 03:39:39 +01:00
hlky a8df0f1ffb Modular APG (#10173) 2024-12-10 08:22:42 -10:00
yiyixuxu ace53e2d2f update/refactor 2024-12-10 03:41:28 +01:00
yiyixuxu ffc2992fc2 add autostep (not complete) 2024-11-16 22:42:06 +01:00
yiyixuxu c70a285c2c style 2024-10-30 10:33:25 +01:00
yiyixuxu 8b811feece refactor, from_pretrained, from_pipe, remove_blocks, replace_blocks 2024-10-30 10:13:03 +01:00
yiyixuxu 37e8dc7a59 remove img2img blocksgit status consolidate text2img and img2img 2024-10-28 00:37:48 +01:00
yiyixuxu 024a9f5de3 fix so that run_blocks can work with inputs in the state 2024-10-27 18:52:56 +01:00
yiyixuxu 005195c23e add 2024-10-27 15:18:10 +01:00
yiyixuxu 6742f160df up 2024-10-27 14:59:31 +01:00
yiyixuxu 540d303250 refactor guider 2024-10-26 21:17:06 +02:00
yiyixuxu f1b3036ca1 update pag guider - draft 2024-10-24 00:14:59 +02:00
yiyixuxu 46ec1743a2 refactor guider, remove prepareguidance step to be combinedd into denoisestep 2024-10-23 21:42:40 +02:00
yiyixuxu 70272b1108 combine controlnetstep into contronetdesnoisestep 2024-10-20 19:45:00 +02:00
yiyixuxu 2b6dcbfa1d fix controlnet 2024-10-20 19:23:37 +02:00
yiyixuxu af9572d759 controlnet 2024-10-19 12:36:12 +02:00
yiyixuxu ddea157979 add from_pipe + run_blocks 2024-10-17 20:02:36 +02:00
yiyixuxu ad3f9a26c0 update img2img, result match 2024-10-17 05:47:15 +02:00
yiyixuxu e8d0980f9f add img2img support - output does not match with non-modular pipeline completely yet (look into later) 2024-10-16 20:56:39 +02:00
yiyixuxu 52a7f1cb97 add dataflow info for each block in builder _repr_ 2024-10-16 09:04:32 +02:00
yiyixuxu 33f85fadf6 add 2024-10-14 19:16:23 +02:00
135 changed files with 18949 additions and 687 deletions
+49
View File
@@ -180,6 +180,55 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-compile-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
python utils/print_env.py
- name: Run torch compile tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: torch_compile_test_reports
path: reports
- name: Generate Report and Notify Channel
if: always()
run: |
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_big_gpu_torch_tests:
name: Torch tests on big GPU
strategy:
+1 -1
View File
@@ -335,7 +335,7 @@ jobs:
- name: Environment
run: |
python utils/print_env.py
- name: Run example tests on GPU
- name: Run torch compile tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
+5
View File
@@ -28,6 +28,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
<Tip>
@@ -91,6 +92,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
## HiDreamImageLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
## LoraBaseMixin
[[autodoc]] loaders.lora_base.LoraBaseMixin
+1 -1
View File
@@ -347,7 +347,7 @@ image = pipe(
height=1024,
prompt="wearing sunglasses",
negative_prompt="",
true_cfg=4.0,
true_cfg_scale=4.0,
generator=torch.Generator().manual_seed(4444),
ip_adapter_image=image,
).images[0]
+1 -1
View File
@@ -24,7 +24,7 @@
## Generating Videos with Wan 2.1
We will first need to install some addtional dependencies.
We will first need to install some additional dependencies.
```shell
pip install -u ftfy imageio-ffmpeg imageio
+1 -1
View File
@@ -216,7 +216,7 @@ Setting the `<ID_TOKEN>` is not necessary. From some limited experimentation, we
> - The original repository uses a `lora_alpha` of `1`. We found this not suitable in many runs, possibly due to difference in modeling backends and training settings. Our recommendation is to set to the `lora_alpha` to either `rank` or `rank // 2`.
> - If you're training on data whose captions generate bad results with the original model, a `rank` of 64 and above is good and also the recommendation by the team behind CogVideoX. If the generations are already moderately good on your training captions, a `rank` of 16/32 should work. We found that setting the rank too low, say `4`, is not ideal and doesn't produce promising results.
> - The authors of CogVideoX recommend 4000 training steps and 100 training videos overall to achieve the best result. While that might yield the best results, we found from our limited experimentation that 2000 steps and 25 videos could also be sufficient.
> - When using the Prodigy opitimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From my very limited testing, I found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `prodigy_safeguard_warmup` and `--prodigy_decouple`.
> - When using the Prodigy optimizer for training, one can follow the recommendations from [this](https://huggingface.co/blog/sdxl_lora_advanced_script) blog. Prodigy tends to overfit quickly. From my very limited testing, I found a learning rate of `0.5` to be suitable in addition to `--prodigy_use_bias_correction`, `prodigy_safeguard_warmup` and `--prodigy_decouple`.
> - The recommended learning rate by the CogVideoX authors and from our experimentation with Adam/AdamW is between `1e-3` and `1e-4` for a dataset of 25+ videos.
>
> Note that our testing is not exhaustive due to limited time for exploration. Our recommendation would be to play around with the different knobs and dials to find the best settings for your data.
+1 -1
View File
@@ -589,7 +589,7 @@ For stage 2 of DeepFloyd IF with DreamBooth, pay attention to these parameters:
* `--learning_rate=5e-6`, use a lower learning rate with a smaller effective batch size
* `--resolution=256`, the expected resolution for the upscaler
* `--train_batch_size=2` and `--gradient_accumulation_steps=6`, to effectively train on images wiht faces requires larger batch sizes
* `--train_batch_size=2` and `--gradient_accumulation_steps=6`, to effectively train on images with faces requires larger batch sizes
```bash
export MODEL_NAME="DeepFloyd/IF-II-L-v1.0"
+1 -1
View File
@@ -89,7 +89,7 @@ Many of the basic and important parameters are described in the [Text-to-image](
As with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. Instead, this guide takes a look at the T2I-Adapter relevant parts of the script.
The training script begins by preparing the dataset. This incudes [tokenizing](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L674) the prompt and [applying transforms](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L714) to the images and conditioning images.
The training script begins by preparing the dataset. This includes [tokenizing](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L674) the prompt and [applying transforms](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L714) to the images and conditioning images.
```py
conditioning_image_transforms = transforms.Compose(
@@ -2181,7 +2181,7 @@ def main(args):
# Predict the noise residual
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
+49 -2
View File
@@ -86,6 +86,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://arxiv.org/abs/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -5381,7 +5382,7 @@ pipe = DiffusionPipeline.from_pretrained(
# Here we need use pipeline internal unet model
pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
# Load aditional layers to the model
# Load additional layers to the model
pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype)
# Enable vae tiling
@@ -5432,4 +5433,50 @@ cropped_image = gen_image.crop((0, 0, width_init, height_init))
cropped_image.save("data/result.png")
````
### Result
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
# Stable Diffusion 3 InstructPix2Pix Pipeline
This the implementation of the Stable Diffusion 3 InstructPix2Pix Pipeline, based on the HuggingFace Diffusers.
## Example Usage
This pipeline aims to edit image based on user's instruction by using SD3
````py
import torch
from diffusers import SD3Transformer2DModel
from diffusers import DiffusionPipeline
from diffusers.utils import load_image
resolution = 512
image = load_image("https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png").resize(
(resolution, resolution)
)
edit_instruction = "Turn sky into a sunny one"
pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-3-medium-diffusers", custom_pipeline="pipeline_stable_diffusion_3_instruct_pix2pix", torch_dtype=torch.float16).to('cuda')
pipe.transformer = SD3Transformer2DModel.from_pretrained("CaptainZZZ/sd3-instructpix2pix",torch_dtype=torch.float16).to('cuda')
edited_image = pipe(
prompt=edit_instruction,
image=image,
height=resolution,
width=resolution,
guidance_scale=7.5,
image_guidance_scale=1.5,
num_inference_steps=30,
).images[0]
edited_image.save("edited_image.png")
````
|Original|Edited|
|---|---|
|![Original image](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/StableDiffusion3InstructPix2Pix/mountain.png)|![Edited image](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/StableDiffusion3InstructPix2Pix/edited.png)
### Note
This model is trained on 512x512, so input size is better on 512x512.
For better editing performance, please refer to this powerful model https://huggingface.co/BleachNick/SD3_UltraEdit_freeform and Paper "UltraEdit: Instruction-based Fine-Grained Image
Editing at Scale", many thanks to their contribution!
+3 -3
View File
@@ -312,9 +312,9 @@ if __name__ == "__main__":
# These are the coordinates of the output image
out_coordinates = np.arange(1, out_length + 1)
# since both scale-factor and output size can be provided simulatneously, perserving the center of the image requires shifting
# the output coordinates. the deviation is because out_length doesn't necesary equal in_length*scale.
# to keep the center we need to subtract half of this deivation so that we get equal margins for boths sides and center is preserved.
# since both scale-factor and output size can be provided simultaneously, preserving the center of the image requires shifting
# the output coordinates. the deviation is because out_length doesn't necessary equal in_length*scale.
# to keep the center we need to subtract half of this deviation so that we get equal margins for both sides and center is preserved.
shifted_out_coordinates = out_coordinates - (out_length - in_length * scale) / 2
# These are the matching positions of the output-coordinates on the input image coordinates.
+4 -4
View File
@@ -351,7 +351,7 @@ def my_forward(
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
added_cond_kwargs: (`dict`, *optional*):
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
Returns:
@@ -864,9 +864,9 @@ def get_flow_and_interframe_paras(flow_model, imgs):
class AttentionControl:
"""
Control FRESCO-based attention
* enable/diable spatial-guided attention
* enable/diable temporal-guided attention
* enable/diable cross-frame attention
* enable/disable spatial-guided attention
* enable/disable temporal-guided attention
* enable/disable cross-frame attention
* collect intermediate attention feature (for spatial-guided attention)
"""
+1 -1
View File
@@ -34,7 +34,7 @@ class RASGAttnProcessor:
temb: Optional[torch.Tensor] = None,
scale: float = 1.0,
) -> torch.Tensor:
# Same as the default AttnProcessor up untill the part where similarity matrix gets saved
# Same as the default AttnProcessor up until the part where similarity matrix gets saved
downscale_factor = self.mask_resoltuion // hidden_states.shape[1]
residual = hidden_states
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -889,7 +889,7 @@ def main(args):
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
)
# Make one log on every process with the configuration for debugging.
@@ -721,7 +721,7 @@ def main(args):
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
)
# Make one log on every process with the configuration for debugging.
@@ -884,7 +884,7 @@ def main(args):
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
)
# Make one log on every process with the configuration for debugging.
@@ -854,7 +854,7 @@ def main(args):
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
)
# Make one log on every process with the configuration for debugging.
@@ -894,7 +894,7 @@ def main(args):
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes
split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be divided by the number of processes assuming batches are multiplied by the number of processes
)
# Make one log on every process with the configuration for debugging.
+15 -2
View File
@@ -6,7 +6,19 @@ Training script provided by LibAI, which is an institution dedicated to the prog
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
> Flux can be quite expensive to run on consumer hardware devices and as a result, ControlNet training of it comes with higher memory requirements than usual.
Here is a gpu memory consumption for reference, tested on a single A100 with 80G.
| period | GPU |
| - | - |
| load as float32 | ~70G |
| mv transformer and vae to bf16 | ~48G |
| pre compute txt embeddings | ~62G |
| **offload te to cpu** | ~30G |
| training | ~58G |
| validation | ~71G |
> **Gated access**
>
@@ -98,8 +110,9 @@ accelerate launch train_controlnet_flux.py \
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_accumulation_steps=16 \
--report_to="wandb" \
--lr_scheduler="cosine" \
--num_double_layers=4 \
--num_single_layers=0 \
--seed=42 \
+18 -6
View File
@@ -148,7 +148,7 @@ def log_validation(
pooled_prompt_embeds=pooled_prompt_embeds,
control_image=validation_image,
num_inference_steps=28,
controlnet_conditioning_scale=0.7,
controlnet_conditioning_scale=1,
guidance_scale=3.5,
generator=generator,
).images[0]
@@ -639,6 +639,15 @@ def parse_args(input_args=None):
action="store_true",
help="Enable model cpu offload and save memory.",
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator):
def prepare_train_dataset(dataset, accelerator):
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(args.resolution, interpolation=interpolation),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
@@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(args.resolution, interpolation=interpolation),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
@@ -1085,8 +1098,6 @@ def main(args):
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds, "text_ids": text_ids}
train_dataset = get_train_dataset(args, accelerator)
text_encoders = [text_encoder_one, text_encoder_two]
tokenizers = [tokenizer_one, tokenizer_two]
compute_embeddings_fn = functools.partial(
compute_embeddings,
flux_controlnet_pipeline=flux_controlnet_pipeline,
@@ -1103,7 +1114,8 @@ def main(args):
compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
)
del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
text_encoder_one.to("cpu")
text_encoder_two.to("cpu")
free_memory()
# Then get the training dataset ready to be passed to the dataloader.
+41 -3
View File
@@ -134,7 +134,25 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
validation_image = validation_image.resize((args.resolution, args.resolution))
try:
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
except (AttributeError, KeyError):
supported_interpolation_modes = [
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
]
raise ValueError(
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
)
transform = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=interpolation),
transforms.CenterCrop(args.resolution),
]
)
validation_image = transform(validation_image)
images = []
@@ -587,6 +605,15 @@ def parse_args(input_args=None):
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -732,9 +759,20 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom
def prepare_train_dataset(dataset, accelerator):
try:
interpolation_mode = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper())
except (AttributeError, KeyError):
supported_interpolation_modes = [
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
]
raise ValueError(
f"Interpolation mode {args.image_interpolation_mode} is not supported. "
f"Please select one of the following: {', '.join(supported_interpolation_modes)}"
)
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(args.resolution, interpolation=interpolation_mode),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
@@ -743,7 +781,7 @@ def prepare_train_dataset(dataset, accelerator):
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(args.resolution, interpolation=interpolation_mode),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
]
+119
View File
@@ -0,0 +1,119 @@
# DreamBooth training example for HiDream Image
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
The `train_dreambooth_lora_hidream.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/).
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_hidream.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
### 3d icon example
For this example we will use some 3d icon images: https://huggingface.co/datasets/linoyts/3d_icon.
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
> [!NOTE]
> The following training configuration prioritizes lower memory consumption by using gradient checkpointing,
> 8-bit Adam optimizer, latent caching, offloading, no validation.
> all text embeddings are pre-computed to save memory.
```bash
export MODEL_NAME="HiDream-ai/HiDream-I1-Dev"
export INSTANCE_DIR="linoyts/3d_icon"
export OUTPUT_DIR="trained-hidream-lora"
accelerate launch train_dreambooth_lora_hidream.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="3d icon" \
--caption_column="prompt"\
--validation_prompt="a 3dicon, a llama eating ramen" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--use_8bit_adam \
--rank=8 \
--learning_rate=2e-4 \
--report_to="wandb" \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--max_train_steps=1000 \
--cache_latents\
--gradient_checkpointing \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
For using `push_to_hub`, make you're logged into your Hugging Face account:
```bash
huggingface-cli login
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
## Notes
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
* `--rank`: The rank of the LoRA layers. The higher the rank, the more parameters are trained. The default is 16.
We provide several options for optimizing memory optimization:
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/) of the `HiDreamImagePipeline` to know more about the model.
@@ -0,0 +1,8 @@
accelerate>=1.4.0
torchvision
transformers>=4.50.0
ftfy
tensorboard
Jinja2
peft>=0.14.0
sentencepiece
@@ -0,0 +1,220 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe"
text_encoder_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer_4_path = "hf-internal-testing/tiny-random-LlamaForCausalLM"
script_path = "examples/dreambooth/train_dreambooth_lora_hidream.py"
transformer_layer_type = "double_stream_blocks.0.block.attn1.to_k"
def test_dreambooth_lora_hidream(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir {self.instance_data_dir}
--resolution 32
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# `self.transformer_layer_type` should be in the state dict.
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 16
""".split()
test_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--pretrained_text_encoder_4_name_or_path {self.text_encoder_4_path}
--pretrained_tokenizer_4_name_or_path {self.tokenizer_4_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--resolution=32
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 16
""".split()
resume_run_args.extend(["--instance_prompt", ""])
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+14 -2
View File
@@ -618,6 +618,15 @@ def parse_args(input_args=None):
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -737,7 +746,10 @@ class DreamBoothDataset(Dataset):
self.instance_images.extend(itertools.repeat(img, repeats))
self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
@@ -1622,7 +1634,7 @@ def main(args):
# Predict the noise residual
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
+14 -1
View File
@@ -524,6 +524,15 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
@@ -601,9 +610,13 @@ class DreamBoothDataset(Dataset):
else:
self.class_data_root = None
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(size, interpolation=interpolation),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
@@ -1749,7 +1749,7 @@ def main(args):
# Predict the noise residual
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
File diff suppressed because it is too large Load Diff
@@ -1088,7 +1088,7 @@ def main(args):
text_ids = batch["text_ids"].to(device=accelerator.device, dtype=weight_dtype)
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
+18 -2
View File
@@ -499,6 +499,15 @@ def parse_args():
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -787,10 +796,17 @@ def main():
)
return inputs.input_ids
# Preprocessing the datasets.
# Get the specified interpolation method from the args
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
# Raise an error if the interpolation method is invalid
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
# Data preprocessing transformations
train_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
@@ -418,6 +418,15 @@ def parse_args():
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -649,10 +658,17 @@ def main():
)
return inputs.input_ids
# Preprocessing the datasets.
# Get the specified interpolation method from the args
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
# Raise an error if the interpolation method is invalid
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
# Data preprocessing transformations
train_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
+6
View File
@@ -239,6 +239,7 @@ else:
"KarrasVePipeline",
"LDMPipeline",
"LDMSuperResolutionPipeline",
"ModularPipeline",
"PNDMPipeline",
"RePaintPipeline",
"ScoreSdeVePipeline",
@@ -493,10 +494,12 @@ else:
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLModularPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLAutoPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
"StableVideoDiffusionPipeline",
@@ -834,6 +837,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KarrasVePipeline,
LDMPipeline,
LDMSuperResolutionPipeline,
ModularPipeline,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
@@ -1066,10 +1070,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularPipeline,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableVideoDiffusionPipeline,
+745
View File
@@ -0,0 +1,745 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from .models.attention_processor import (
Attention,
AttentionProcessor,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
)
from .utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
class CFGGuider:
"""
This class is used to guide the pipeline with CFG (Classifier-Free Guidance).
"""
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0 and not self._disable_guidance
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def guidance_scale(self):
return self._guidance_scale
@property
def batch_size(self):
return self._batch_size
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
# a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead
disable_guidance = guider_kwargs.get("disable_guidance", False)
guidance_scale = guider_kwargs.get("guidance_scale", None)
if guidance_scale is None:
raise ValueError("guidance_scale is not provided in guider_kwargs")
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
batch_size = guider_kwargs.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size is not provided in guider_kwargs")
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._batch_size = batch_size
self._disable_guidance = disable_guidance
def reset_guider(self, pipeline):
pass
def maybe_update_guider(self, pipeline, timestep):
pass
def maybe_update_input(self, pipeline, cond_input):
pass
def _maybe_split_prepared_input(self, cond):
"""
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
It determines whether to split the input based on its batch size relative to the expected batch size.
Args:
cond (torch.Tensor): The conditional input tensor to process.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The negative conditional input (uncond_input)
- The positive conditional input (cond_input)
"""
if cond.shape[0] == self.batch_size * 2:
neg_cond = cond[0 : self.batch_size]
cond = cond[self.batch_size :]
return neg_cond, cond
elif cond.shape[0] == self.batch_size:
return cond, cond
else:
raise ValueError(f"Unsupported input shape: {cond.shape}")
def _is_prepared_input(self, cond):
"""
Check if the input is already prepared for Classifier-Free Guidance (CFG).
Args:
cond (torch.Tensor): The conditional input tensor to check.
Returns:
bool: True if the input is already prepared, False otherwise.
"""
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
return cond_tensor.shape[0] == self.batch_size * 2
def prepare_input(
self,
cond_input: Union[torch.Tensor, List[torch.Tensor]],
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Prepare the input for CFG.
Args:
cond_input (Union[torch.Tensor, List[torch.Tensor]]):
The conditional input. It can be a single tensor or a
list of tensors. It must have the same length as `negative_cond_input`.
negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
single tensor or a list of tensors. It must have the same length as `cond_input`.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
"""
# we check if cond_input already has CFG applied, and split if it is the case.
if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
return cond_input
if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
if isinstance(cond_input, list):
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
else:
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
raise ValueError(
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
)
if isinstance(cond_input, (list, tuple)):
if not self.do_classifier_free_guidance:
return cond_input
if len(negative_cond_input) != len(cond_input):
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
prepared_input = []
for neg_cond, cond in zip(negative_cond_input, cond_input):
if neg_cond.shape[0] != cond.shape[0]:
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
return prepared_input
elif isinstance(cond_input, torch.Tensor):
if not self.do_classifier_free_guidance:
return cond_input
else:
return torch.cat([negative_cond_input, cond_input], dim=0)
else:
raise ValueError(f"Unsupported input type: {type(cond_input)}")
def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_classifier_free_guidance:
return model_output
noise_pred_uncond, noise_pred_text = model_output.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
return noise_pred
class PAGGuider:
"""
This class is used to guide the pipeline with CFG (Classifier-Free Guidance).
"""
def __init__(
self,
pag_applied_layers: Union[str, List[str]],
pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
PAGCFGIdentitySelfAttnProcessor2_0(),
PAGIdentitySelfAttnProcessor2_0(),
),
):
r"""
Set the the self-attention layers to apply PAG. Raise ValueError if the input is invalid.
Args:
pag_applied_layers (`str` or `List[str]`):
One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
PAG is to be applied. A few ways of expected usage are as follows:
- Single layers specified as - "blocks.{layer_index}"
- Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...]
- Multiple layers as a block name - "mid"
- Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
pag_attn_processors:
(`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
attention processor is for PAG with CFG disabled (unconditional only).
"""
if not isinstance(pag_applied_layers, list):
pag_applied_layers = [pag_applied_layers]
if pag_attn_processors is not None:
if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
raise ValueError("Expected a tuple of two attention processors")
for i in range(len(pag_applied_layers)):
if not isinstance(pag_applied_layers[i], str):
raise ValueError(
f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}"
)
self.pag_applied_layers = pag_applied_layers
self._pag_attn_processors = pag_attn_processors
def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance):
r"""
Set the attention processor for the PAG layers.
"""
pag_attn_processors = self._pag_attn_processors
pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]
def is_self_attn(module: nn.Module) -> bool:
r"""
Check if the module is self-attention module based on its name.
"""
return isinstance(module, Attention) and not module.is_cross_attention
def is_fake_integral_match(layer_id, name):
layer_id = layer_id.split(".")[-1]
name = name.split(".")[-1]
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
for layer_id in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the unet model
target_modules = []
for name, module in model.named_modules():
# Identify the following simple cases:
# (1) Self Attention layer existing
# (2) Whether the module name matches pag layer id even partially
# (3) Make sure it's not a fake integral match if the layer_id ends with a number
# For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1"
if (
is_self_attn(module)
and re.search(layer_id, name) is not None
and not is_fake_integral_match(layer_id, name)
):
logger.debug(f"Applying PAG to layer: {name}")
target_modules.append(module)
if len(target_modules) == 0:
raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")
for module in target_modules:
module.processor = pag_attn_proc
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and not self._disable_guidance
@property
def do_perturbed_attention_guidance(self):
return self._pag_scale > 0 and not self._disable_guidance
@property
def do_pag_adaptive_scaling(self):
return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def batch_size(self):
return self._batch_size
@property
def pag_scale(self):
return self._pag_scale
@property
def pag_adaptive_scale(self):
return self._pag_adaptive_scale
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
pag_scale = guider_kwargs.get("pag_scale", 3.0)
pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0)
batch_size = guider_kwargs.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size is a required argument for PAGGuider")
guidance_scale = guider_kwargs.get("guidance_scale", None)
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
disable_guidance = guider_kwargs.get("disable_guidance", False)
if guidance_scale is None:
raise ValueError("guidance_scale is a required argument for PAGGuider")
self._pag_scale = pag_scale
self._pag_adaptive_scale = pag_adaptive_scale
self._guidance_scale = guidance_scale
self._disable_guidance = disable_guidance
self._guidance_rescale = guidance_rescale
self._batch_size = batch_size
if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None:
pipeline.original_attn_proc = pipeline.unet.attn_processors
self._set_pag_attn_processor(
model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer,
pag_applied_layers=self.pag_applied_layers,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
def reset_guider(self, pipeline):
if (
self.do_perturbed_attention_guidance
and hasattr(pipeline, "original_attn_proc")
and pipeline.original_attn_proc is not None
):
pipeline.unet.set_attn_processor(pipeline.original_attn_proc)
pipeline.original_attn_proc = None
def maybe_update_guider(self, pipeline, timestep):
pass
def maybe_update_input(self, pipeline, cond_input):
pass
def _is_prepared_input(self, cond):
"""
Check if the input is already prepared for Perturbed Attention Guidance (PAG).
Args:
cond (torch.Tensor): The conditional input tensor to check.
Returns:
bool: True if the input is already prepared, False otherwise.
"""
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
return cond_tensor.shape[0] == self.batch_size * 3
def _maybe_split_prepared_input(self, cond):
"""
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
It determines whether to split the input based on its batch size relative to the expected batch size.
Args:
cond (torch.Tensor): The conditional input tensor to process.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The negative conditional input (uncond_input)
- The positive conditional input (cond_input)
"""
if cond.shape[0] == self.batch_size * 3:
neg_cond = cond[0 : self.batch_size]
cond = cond[self.batch_size : self.batch_size * 2]
return neg_cond, cond
elif cond.shape[0] == self.batch_size:
return cond, cond
else:
raise ValueError(f"Unsupported input shape: {cond.shape}")
def prepare_input(
self,
cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None,
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]:
"""
Prepare the input for CFG.
Args:
cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
The conditional input. It can be a single tensor or a
list of tensors. It must have the same length as `negative_cond_input`.
negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
The negative conditional input. It can be a single tensor or a list of tensors. It must have the same
length as `cond_input`.
Returns:
Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input.
"""
# we check if cond_input already has CFG applied, and split if it is the case.
if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance:
return cond_input
if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance:
if isinstance(cond_input, list):
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
else:
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None:
raise ValueError(
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
)
if isinstance(cond_input, (list, tuple)):
if not self.do_perturbed_attention_guidance:
return cond_input
if len(negative_cond_input) != len(cond_input):
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
prepared_input = []
for neg_cond, cond in zip(negative_cond_input, cond_input):
if neg_cond.shape[0] != cond.shape[0]:
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
cond = torch.cat([cond] * 2, dim=0)
if self.do_classifier_free_guidance:
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
else:
prepared_input.append(cond)
return prepared_input
elif isinstance(cond_input, torch.Tensor):
if not self.do_perturbed_attention_guidance:
return cond_input
cond_input = torch.cat([cond_input] * 2, dim=0)
if self.do_classifier_free_guidance:
return torch.cat([negative_cond_input, cond_input], dim=0)
else:
return cond_input
else:
raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}")
def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_perturbed_attention_guidance:
return model_output
if self.do_pag_adaptive_scaling:
pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0)
else:
pag_scale = self._pag_scale
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3)
noise_pred = (
noise_pred_uncond
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ pag_scale * (noise_pred_text - noise_pred_perturb)
)
else:
noise_pred_text, noise_pred_perturb = model_output.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
if self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
return noise_pred
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
class APGGuider:
"""
This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
"""
def normalized_guidance(
self,
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: MomentumBuffer = None,
norm_threshold: float = 0.0,
eta: float = 1.0,
):
"""
Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion
Models](https://arxiv.org/pdf/2410.02416)
"""
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
normalized_update = diff_orthogonal + eta * diff_parallel
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
return pred_guided
@property
def adaptive_projected_guidance_momentum(self):
return self._adaptive_projected_guidance_momentum
@property
def adaptive_projected_guidance_rescale_factor(self):
return self._adaptive_projected_guidance_rescale_factor
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0 and not self._disable_guidance
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def guidance_scale(self):
return self._guidance_scale
@property
def batch_size(self):
return self._batch_size
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
disable_guidance = guider_kwargs.get("disable_guidance", False)
guidance_scale = guider_kwargs.get("guidance_scale", None)
if guidance_scale is None:
raise ValueError("guidance_scale is not provided in guider_kwargs")
adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
"adaptive_projected_guidance_rescale_factor", 15.0
)
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
batch_size = guider_kwargs.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size is not provided in guider_kwargs")
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._batch_size = batch_size
self._disable_guidance = disable_guidance
if adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
else:
self.momentum_buffer = None
self.scheduler = pipeline.scheduler
def reset_guider(self, pipeline):
pass
def maybe_update_guider(self, pipeline, timestep):
pass
def maybe_update_input(self, pipeline, cond_input):
pass
def _maybe_split_prepared_input(self, cond):
"""
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
It determines whether to split the input based on its batch size relative to the expected batch size.
Args:
cond (torch.Tensor): The conditional input tensor to process.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The negative conditional input (uncond_input)
- The positive conditional input (cond_input)
"""
if cond.shape[0] == self.batch_size * 2:
neg_cond = cond[0 : self.batch_size]
cond = cond[self.batch_size :]
return neg_cond, cond
elif cond.shape[0] == self.batch_size:
return cond, cond
else:
raise ValueError(f"Unsupported input shape: {cond.shape}")
def _is_prepared_input(self, cond):
"""
Check if the input is already prepared for Classifier-Free Guidance (CFG).
Args:
cond (torch.Tensor): The conditional input tensor to check.
Returns:
bool: True if the input is already prepared, False otherwise.
"""
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
return cond_tensor.shape[0] == self.batch_size * 2
def prepare_input(
self,
cond_input: Union[torch.Tensor, List[torch.Tensor]],
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Prepare the input for CFG.
Args:
cond_input (Union[torch.Tensor, List[torch.Tensor]]):
The conditional input. It can be a single tensor or a
list of tensors. It must have the same length as `negative_cond_input`.
negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
single tensor or a list of tensors. It must have the same length as `cond_input`.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
"""
# we check if cond_input already has CFG applied, and split if it is the case.
if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
return cond_input
if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
if isinstance(cond_input, list):
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
else:
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
raise ValueError(
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
)
if isinstance(cond_input, (list, tuple)):
if not self.do_classifier_free_guidance:
return cond_input
if len(negative_cond_input) != len(cond_input):
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
prepared_input = []
for neg_cond, cond in zip(negative_cond_input, cond_input):
if neg_cond.shape[0] != cond.shape[0]:
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
return prepared_input
elif isinstance(cond_input, torch.Tensor):
if not self.do_classifier_free_guidance:
return cond_input
else:
return torch.cat([negative_cond_input, cond_input], dim=0)
else:
raise ValueError(f"Unsupported input type: {type(cond_input)}")
def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_classifier_free_guidance:
return model_output
if latents is None:
raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")
sigma = self.scheduler.sigmas[self.scheduler.step_index]
noise_pred = latents - sigma * model_output
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = self.normalized_guidance(
noise_pred_text,
noise_pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.adaptive_projected_guidance_rescale_factor,
)
noise_pred = (latents - noise_pred) / sigma
if self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
return noise_pred
+35 -19
View File
@@ -13,7 +13,7 @@
# limitations under the License.
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
@@ -55,9 +55,9 @@ class ModuleGroup:
parameters: Optional[List[torch.nn.Parameter]] = None,
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage=False,
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
) -> None:
self.modules = modules
@@ -115,8 +115,13 @@ class ModuleGroup:
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
current_stream = torch.cuda.current_stream() if self.record_stream else None
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
@@ -162,9 +167,15 @@ class ModuleGroup:
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
else torch.cuda
)
if self.stream is not None:
if not self.record_stream:
torch.cuda.current_stream().synchronize()
torch_accelerator_module.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
@@ -429,8 +440,10 @@ def apply_group_offloading(
if use_stream:
if torch.cuda.is_available():
stream = torch.cuda.Stream()
elif hasattr(torch, "xpu") and torch.xpu.is_available():
stream = torch.Stream()
else:
raise ValueError("Using streams for data transfer requires a CUDA device.")
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
@@ -468,7 +481,7 @@ def _apply_group_offloading_block_level(
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
@@ -486,7 +499,7 @@ def _apply_group_offloading_block_level(
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`, *optional*):
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
@@ -498,6 +511,11 @@ def _apply_group_offloading_block_level(
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
if stream is not None and num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
)
num_blocks_per_group = 1
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -521,7 +539,7 @@ def _apply_group_offloading_block_level(
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
onload_self=True,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
@@ -529,12 +547,8 @@ def _apply_group_offloading_block_level(
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
next_group = (
matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
)
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, next_group)
_apply_group_offloading_hook(group_module, group, None)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -560,8 +574,10 @@ def _apply_group_offloading_block_level(
record_stream=False,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
_apply_group_offloading_hook(module, unmatched_group, next_group)
if stream is None:
_apply_group_offloading_hook(module, unmatched_group, None)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
def _apply_group_offloading_leaf_level(
@@ -569,7 +585,7 @@ def _apply_group_offloading_leaf_level(
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
) -> None:
@@ -589,7 +605,7 @@ def _apply_group_offloading_leaf_level(
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream`, *optional*):
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+6 -1
View File
@@ -116,6 +116,7 @@ class VaeImageProcessor(ConfigMixin):
vae_scale_factor: int = 8,
vae_latent_channels: int = 4,
resample: str = "lanczos",
reducing_gap: int = None,
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
@@ -498,7 +499,11 @@ class VaeImageProcessor(ConfigMixin):
raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
if isinstance(image, PIL.Image.Image):
if resize_mode == "default":
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
image = image.resize(
(width, height),
resample=PIL_INTERPOLATION[self.config.resample],
reducing_gap=self.config.reducing_gap,
)
elif resize_mode == "fill":
image = self._resize_and_fill(image, width, height)
elif resize_mode == "crop":
+4
View File
@@ -77,12 +77,14 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
"ModularIPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -101,6 +103,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
ModularIPAdapterMixin,
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
@@ -108,6 +111,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
+259
View File
@@ -356,6 +356,265 @@ class IPAdapterMixin:
)
self.unet.set_attn_processor(attn_procs)
class ModularIPAdapterMixin:
"""Mixin for handling IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
subfolder: Union[str, List[str]],
weight_name: Union[str, List[str]],
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if "image_proj" not in keys and "ip_adapter" not in keys:
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
state_dicts.append(state_dict)
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
unet_name = getattr(self, "unet_name", "unet")
unet = getattr(self, unet_name)
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
extra_loras = unet._load_ip_adapter_loras(state_dicts)
if extra_loras != {}:
if not USE_PEFT_BACKEND:
logger.warning("PEFT backend is required to load these weights.")
else:
# apply the IP Adapter Face ID LoRA weights
peft_config = getattr(unet, "peft_config", {})
for k, lora in extra_loras.items():
if f"faceid_{k}" not in peft_config:
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
def set_ip_adapter_scale(self, scale):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
# To use style block only
scale = {
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style+layout blocks
scale = {
"down": {"block_2": [0.0, 1.0]},
"up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
# To use style and layout from 2 reference images
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
pipeline.set_ip_adapter_scale(scales)
```
"""
unet_name = getattr(self, "unet_name", "unet")
unet = getattr(self, unet_name)
if not isinstance(scale, list):
scale = [scale]
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
f"{len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
if isinstance(scale_config, dict):
for k, s in scale_config.items():
if attn_name.startswith(k):
attn_processor.scale[i] = s
else:
attn_processor.scale[i] = scale_config
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove hidden encoder
if self.unet is None:
return
self.unet.encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = None
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
self.unet.text_encoder_hid_proj = None
self.unet.config.encoder_hid_dim_type = "text_proj"
# restore original Unet attention processors layers
attn_procs = {}
for name, value in self.unet.attn_processors.items():
attn_processor_class = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
)
attn_procs[name] = (
attn_processor_class
if isinstance(
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
)
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
+6 -3
View File
@@ -441,7 +441,7 @@ def _func_optionally_disable_offloading(_pipeline):
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
@@ -491,6 +491,7 @@ class LoraBaseMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
@classmethod
@@ -713,8 +714,10 @@ class LoraBaseMixin:
# Decompose weights into weights for denoiser and text encoders.
_component_adapter_weights = {}
for component in self._lora_loadable_modules:
model = getattr(self, component)
model = getattr(self, component, None)
if model is None:
logger.warning(f"Model {component} not found in pipeline.")
continue
for adapter_name, weights in zip(adapter_names, adapter_weights):
if isinstance(weights, dict):
component_adapter_weights = weights.pop(component, None)
@@ -433,7 +433,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
@@ -923,7 +923,7 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
+358 -35
View File
@@ -91,18 +91,19 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
)
weight_on_cpu = False
if not module.weight.is_cuda:
if module.weight.device.type == "cpu":
weight_on_cpu = True
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
if is_bnb_4bit_quantized:
module_weight = dequantize_bnb_weight(
module.weight.cuda() if weight_on_cpu else module.weight,
module.weight.to(device) if weight_on_cpu else module.weight,
state=module.weight.quant_state,
dtype=model.dtype,
).data
elif is_gguf_quantized:
module_weight = dequantize_gguf_tensor(
module.weight.cuda() if weight_on_cpu else module.weight,
module.weight.to(device) if weight_on_cpu else module.weight,
)
module_weight = module_weight.to(model.dtype)
else:
@@ -635,7 +636,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
unet_config=self.unet.config if hasattr(self, "unet") else None,
**kwargs,
)
@@ -643,37 +644,40 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=self.unet,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
prefix=f"{self.text_encoder_name}_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if hasattr(self, "unet"):
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=self.unet,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if hasattr(self, "text_encoder"):
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if hasattr(self, "text_encoder_2"):
self.load_lora_into_text_encoder(
state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
prefix=f"{self.text_encoder_name}_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
@validate_hf_hub_args
@@ -5360,6 +5364,325 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`HiDreamImageTransformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
+1
View File
@@ -56,6 +56,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
}
+1
View File
@@ -408,6 +408,7 @@ class UNet2DConditionLoadersMixin:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def save_attn_procs(
@@ -20,12 +20,12 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.attention_processor import AttentionProcessor
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import AttentionProcessor
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
@@ -4,9 +4,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
from ...models.modeling_utils import ModelMixin
from ...utils import logging
from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__)
@@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from ...models.controlnets.controlnet import ControlNetOutput
from ...models.controlnets.controlnet_union import ControlNetUnionModel
from ...models.modeling_utils import ModelMixin
from ...utils import logging
from ..controlnets.controlnet import ControlNetOutput
from ..controlnets.controlnet_union import ControlNetUnionModel
from ..modeling_utils import ModelMixin
logger = logging.get_logger(__name__)
+1 -1
View File
@@ -286,7 +286,7 @@ class KDownsample2D(nn.Module):
class CogVideoXDownsample3D(nn.Module):
# Todo: Wait for paper relase.
# Todo: Wait for paper release.
r"""
A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
+1 -47
View File
@@ -18,7 +18,7 @@ import importlib
import inspect
import os
from array import array
from collections import OrderedDict, defaultdict
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,7 +38,6 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
is_accelerator_device,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -305,51 +304,6 @@ def load_model_dict_into_meta(
return offload_index, state_dict_index
# Taken from
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
def _expand_device_map(device_map, param_names):
new_device_map = {}
for module, device in device_map.items():
new_device_map.update(
{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
# We don't incorporate the `tp_plan` stuff as we don't support it yet.
def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> Dict:
# Remove disk, cpu and meta devices, and cast to proper torch.device
accelerator_device_map = {
param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
}
if not len(accelerator_device_map):
return
total_byte_count = defaultdict(lambda: 0)
for param_name, device in accelerator_device_map.items():
param = model.get_parameter_or_buffer(param_name)
# The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
param_byte_count = param.numel() * param.element_size()
total_byte_count[device] += param_byte_count
# This will kick off the caching allocator to avoid having to Malloc afterwards
for device, byte_count in total_byte_count.items():
if device.type == "cuda":
index = device.index if device.index is not None else torch.cuda.current_device()
device_memory = torch.cuda.mem_get_info(index)[0]
# Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
# than that amount might sometimes lead to unecesary cuda OOM, if the last parameter to be loaded on the device is large,
# and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
# the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
# to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
# Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
# if using e.g. 90% of device size, while a 140GiB device would allocate too little
byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
# Allocate memory
_ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
def _load_state_dict_into_model(
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
-25
View File
@@ -63,9 +63,7 @@ from ..utils.hub_utils import (
populate_model_card,
)
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
_expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
_load_state_dict_into_model,
@@ -1376,24 +1374,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
return super().float(*args)
# Taken from `transformers`.
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
def get_parameter_or_buffer(self, target: str):
"""
Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
`get_parameter()` and `get_buffer()` in a single handy function. Note that it only work if `target` is a leaf
of the model.
"""
try:
return self.get_parameter(target)
except AttributeError:
pass
try:
return self.get_buffer(target)
except AttributeError:
pass
raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
@classmethod
def _load_pretrained_model(
cls,
@@ -1430,11 +1410,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
assign_to_params_buffers = None
error_msgs = []
# Optionally, warmup cuda to load the weights much faster on devices
if device_map is not None:
expanded_device_map = _expand_device_map(device_map, expected_keys)
_caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -18,10 +18,9 @@ import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..attention import BasicTransformerBlock
from ..cache_utils import CacheMixin
from ..embeddings import PatchEmbed
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
@@ -21,16 +21,12 @@ import torch.nn as nn
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
StableAudioAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_2d import Transformer2DModelOutput
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
from ..modeling_utils import ModelMixin
from ..transformers.transformer_2d import Transformer2DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -19,18 +19,13 @@ import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
CogVideoXAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import logging
from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -73,8 +73,9 @@ class CogView4AdaLayerNormZero(nn.Module):
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states = self.norm(hidden_states)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
dtype = hidden_states.dtype
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
emb = self.linear(temb)
(
@@ -111,8 +112,11 @@ class CogView4AdaLayerNormZero(nn.Module):
class CogView4AttnProcessor:
"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
"""
def __init__(self):
@@ -125,8 +129,10 @@ class CogView4AttnProcessor:
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = encoder_hidden_states.dtype
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +148,9 @@ class CogView4AttnProcessor:
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key)
key = attn.norm_k(key).to(dtype=dtype)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
@@ -159,13 +165,14 @@ class CogView4AttnProcessor:
# 4. Attention
if attention_mask is not None:
text_attention_mask = attention_mask.float().to(query.device)
actual_text_seq_length = text_attention_mask.size(1)
new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
new_attention_mask = new_attention_mask.unsqueeze(2)
attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
text_attn_mask = attention_mask
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
text_attn_mask = text_attn_mask.float().to(query.device)
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
mix_attn_mask[:, :text_seq_length] = text_attn_mask
mix_attn_mask = mix_attn_mask.unsqueeze(2)
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
@@ -183,9 +190,276 @@ class CogView4AttnProcessor:
return hidden_states, encoder_hidden_states
class CogView4TrainingAttnProcessor:
"""
Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
embedding on query and key vectors, but does not include spatial normalization.
This processor differs from CogView4AttnProcessor in several important ways:
1. It supports attention masking with variable sequence lengths for multi-resolution training
2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
provided
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
latent_attn_mask: Optional[torch.Tensor] = None,
text_attn_mask: Optional[torch.Tensor] = None,
batch_flag: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
attn (`Attention`):
The attention module.
hidden_states (`torch.Tensor`):
The input hidden states.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states for cross-attention.
latent_attn_mask (`torch.Tensor`, *optional*):
Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
num_latent_tokens).
text_attn_mask (`torch.Tensor`, *optional*):
Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
is used for all text tokens.
batch_flag (`torch.Tensor`, *optional*):
Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
batch1, and samples 3-4 form batch2. If None, no packing is used.
image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
The rotary embedding for the image part of the input.
Returns:
`Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
"""
# Get dimensions and device info
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
dtype = encoder_hidden_states.dtype
device = encoder_hidden_states.device
latent_hidden_states = hidden_states
# Combine text and image streams for joint processing
mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
# 1. Construct attention mask and maybe packing input
# Create default masks if not provided
if text_attn_mask is None:
text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
if latent_attn_mask is None:
latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
# Validate mask shapes and types
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
# Create combined mask for text and image tokens
mixed_attn_mask = torch.ones(
(batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
)
mixed_attn_mask[:, :text_seq_length] = text_attn_mask
mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
# Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
# Handle batch packing if enabled
if batch_flag is not None:
assert batch_flag.dim() == 1
# Determine packed batch size based on batch_flag
packing_batch_size = torch.max(batch_flag).item() + 1
# Calculate actual sequence lengths for each sample based on masks
text_seq_length = torch.sum(text_attn_mask, dim=1)
latent_seq_length = torch.sum(latent_attn_mask, dim=1)
mixed_seq_length = text_seq_length + latent_seq_length
# Calculate packed sequence lengths for each packed batch
mixed_seq_length_packed = [
torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
]
assert len(mixed_seq_length_packed) == packing_batch_size
# Pack sequences by removing padding tokens
mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
# Split the unpadded sequence into packed batches
mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
# Re-pad to create packed batches with right-side padding
mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
mixed_hidden_states_packed,
batch_first=True,
padding_value=0.0,
padding_side="right",
)
# Create attention mask for packed batches
l = mixed_hidden_states_packed_padded.shape[1]
attn_mask_matrix = torch.zeros(
(packing_batch_size, l, l),
dtype=dtype,
device=device,
)
# Fill attention mask with block diagonal matrices
# This ensures that tokens can only attend to other tokens within the same original sample
for idx, mask in enumerate(attn_mask_matrix):
seq_lengths = mixed_seq_length[batch_flag == idx]
offset = 0
for length in seq_lengths:
# Create a block of 1s for each sample in the packed batch
mask[offset : offset + length, offset : offset + length] = 1
offset += length
attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
attn_mask_matrix = attn_mask_matrix.unsqueeze(1) # Add attention head dim
attention_mask = attn_mask_matrix
# Prepare hidden states for attention computation
if batch_flag is None:
# If no packing, just combine text and image tokens
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
else:
# If packing, use the packed sequence
hidden_states = mixed_hidden_states_packed_padded
# 2. QKV projections - convert hidden states to query, key, value
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 3. QK normalization - apply layer norm to queries and keys if configured
if attn.norm_q is not None:
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key).to(dtype=dtype)
# 4. Apply rotary positional embeddings to image tokens only
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
if batch_flag is None:
# Apply RoPE only to image tokens (after text tokens)
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
else:
# For packed batches, need to carefully apply RoPE to appropriate tokens
assert query.shape[0] == packing_batch_size
assert key.shape[0] == packing_batch_size
assert len(image_rotary_emb) == batch_size
rope_idx = 0
for idx in range(packing_batch_size):
offset = 0
# Get text and image sequence lengths for samples in this packed batch
text_seq_length_bi = text_seq_length[batch_flag == idx]
latent_seq_length_bi = latent_seq_length[batch_flag == idx]
# Apply RoPE to each image segment in the packed sequence
for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
mlen = tlen + llen
# Apply RoPE only to image tokens (after text tokens)
query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
query[idx, :, offset + tlen : offset + mlen, :],
image_rotary_emb[rope_idx],
use_real_unbind_dim=-2,
)
key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
key[idx, :, offset + tlen : offset + mlen, :],
image_rotary_emb[rope_idx],
use_real_unbind_dim=-2,
)
offset += mlen
rope_idx += 1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
# Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection - project attention output to model dimension
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
# Split the output back into text and image streams
if batch_flag is None:
# Simple split for non-packed case
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
else:
# For packed case: need to unpack, split text/image, then restore to original shapes
# First, unpad the sequence based on the packed sequence lengths
hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
hidden_states,
lengths=torch.tensor(mixed_seq_length_packed),
batch_first=True,
)
# Concatenate all unpadded sequences
hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
# Split by original sample sequence lengths
hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
assert len(hidden_states_unpack) == batch_size
# Further split each sample's sequence into text and image parts
hidden_states_unpack = [
torch.split(h, [tlen, llen])
for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
]
# Separate text and image sequences
encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
hidden_states_unpad = [h[1] for h in hidden_states_unpack]
# Update the original tensors with the processed values, respecting the attention masks
for idx in range(batch_size):
# Place unpacked text tokens back in the encoder_hidden_states tensor
encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
# Place unpacked image tokens back in the latent_hidden_states tensor
latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
# Update the output hidden states
hidden_states = latent_hidden_states
return hidden_states, encoder_hidden_states
class CogView4TransformerBlock(nn.Module):
def __init__(
self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
self,
dim: int = 2560,
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
) -> None:
super().__init__()
@@ -213,9 +487,11 @@ class CogView4TransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# 1. Timestep conditioning
(
@@ -232,12 +508,14 @@ class CogView4TransformerBlock(nn.Module):
) = self.norm1(hidden_states, encoder_hidden_states, temb)
# 2. Attention
if attention_kwargs is None:
attention_kwargs = {}
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
**kwargs,
**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -402,7 +680,9 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
@@ -422,7 +702,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
if image_rotary_emb is None:
image_rotary_emb = self.rope(hidden_states)
# 2. Patch & Timestep embeddings
p = self.config.patch_size
@@ -438,11 +719,22 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
)
# 4. Output norm & projection
@@ -21,22 +21,22 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import (
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..cache_utils import CacheMixin
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -275,7 +275,14 @@ class HiDreamAttnProcessor:
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
def __init__(
self,
embed_dim,
num_routed_experts=4,
num_activated_experts=2,
aux_loss_alpha=0.01,
_force_inference_output=False,
):
super().__init__()
self.top_k = num_activated_experts
self.n_routed_experts = num_routed_experts
@@ -289,9 +296,10 @@ class MoEGate(nn.Module):
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.randn(self.n_routed_experts, self.gating_dim) / embed_dim**0.5)
self._force_inference_output = _force_inference_output
def forward(self, hidden_states):
bsz, seq_len, h = hidden_states.shape
# print(bsz, seq_len, h)
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, self.weight, None)
@@ -309,7 +317,7 @@ class MoEGate(nn.Module):
topk_weight = topk_weight / denominator
### expert-level computation auxiliary loss
if self.training and self.alpha > 0.0:
if self.training and self.alpha > 0.0 and not self._force_inference_output:
scores_for_aux = scores
aux_topk = self.top_k
# always compute aux loss based on the naive greedy topk method
@@ -341,14 +349,19 @@ class MOEFeedForwardSwiGLU(nn.Module):
hidden_dim: int,
num_routed_experts: int,
num_activated_experts: int,
_force_inference_output: bool = False,
):
super().__init__()
self.shared_experts = HiDreamImageFeedForwardSwiGLU(dim, hidden_dim // 2)
self.experts = nn.ModuleList(
[HiDreamImageFeedForwardSwiGLU(dim, hidden_dim) for i in range(num_routed_experts)]
)
self._force_inference_output = _force_inference_output
self.gate = MoEGate(
embed_dim=dim, num_routed_experts=num_routed_experts, num_activated_experts=num_activated_experts
embed_dim=dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=_force_inference_output,
)
self.num_activated_experts = num_activated_experts
@@ -359,7 +372,7 @@ class MOEFeedForwardSwiGLU(nn.Module):
topk_idx, topk_weight, aux_loss = self.gate(x)
x = x.view(-1, x.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
if self.training and not self._force_inference_output:
x = x.repeat_interleave(self.num_activated_experts, dim=0)
y = torch.empty_like(x, dtype=wtype)
for i, expert in enumerate(self.experts):
@@ -413,6 +426,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
_force_inference_output: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -436,6 +450,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
hidden_dim=4 * dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=_force_inference_output,
)
else:
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
@@ -480,6 +495,7 @@ class HiDreamImageTransformerBlock(nn.Module):
attention_head_dim: int,
num_routed_experts: int = 4,
num_activated_experts: int = 2,
_force_inference_output: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -504,6 +520,7 @@ class HiDreamImageTransformerBlock(nn.Module):
hidden_dim=4 * dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=_force_inference_output,
)
else:
self.ff_i = HiDreamImageFeedForwardSwiGLU(dim=dim, hidden_dim=4 * dim)
@@ -606,6 +623,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
axes_dims_rope: Tuple[int, int] = (32, 32),
max_resolution: Tuple[int, int] = (128, 128),
llama_layers: List[int] = None,
force_inference_output: bool = False,
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -629,6 +647,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=force_inference_output,
)
)
for _ in range(num_layers)
@@ -644,6 +663,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
attention_head_dim=attention_head_dim,
num_routed_experts=num_routed_experts,
num_activated_experts=num_activated_experts,
_force_inference_output=force_inference_output,
)
)
for _ in range(num_single_layers)
@@ -662,7 +682,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_training: bool) -> List[torch.Tensor]:
if is_training:
if is_training and not self.config.force_inference_output:
B, S, F = x.shape
C = F // (self.config.patch_size * self.config.patch_size)
x = (
@@ -771,7 +791,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
if encoder_hidden_states is not None:
deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
deprecate("encoder_hidden_states", "0.35.0", deprecation_message)
encoder_hidden_states_t5 = encoder_hidden_states[0]
encoder_hidden_states_llama3 = encoder_hidden_states[1]
@@ -779,7 +799,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
deprecation_message = (
"Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
)
deprecate("img_ids", "0.34.0", deprecation_message)
deprecate("img_ids", "0.35.0", deprecation_message)
if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
@@ -1068,17 +1068,15 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
latent_sequence_length = hidden_states.shape[1]
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
attention_mask = torch.ones(
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i]] = True
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
attention_mask = attention_mask.masked_fill(mask_indices, False)
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
@@ -18,19 +18,19 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...models.attention import FeedForward, JointTransformerBlock
from ...models.attention_processor import (
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward, JointTransformerBlock
from ..attention_processor import (
Attention,
AttentionProcessor,
FusedJointAttnProcessor2_0,
JointAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+1 -1
View File
@@ -358,7 +358,7 @@ class KUpsample2D(nn.Module):
class CogVideoXUpsample3D(nn.Module):
r"""
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper release.
Args:
in_channels (`int`):
+6
View File
@@ -47,6 +47,7 @@ else:
"AutoPipelineForInpainting",
"AutoPipelineForText2Image",
]
_import_structure["modular_pipeline"] = ["ModularPipeline"]
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
_import_structure["ddim"] = ["DDIMPipeline"]
@@ -329,6 +330,8 @@ else:
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLModularPipeline",
"StableDiffusionXLAutoPipeline",
]
)
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
@@ -478,6 +481,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .modular_pipeline import ModularPipeline
from .pipeline_utils import (
AudioPipelineOutput,
DiffusionPipeline,
@@ -702,7 +706,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
)
from .stable_video_diffusion import StableVideoDiffusionPipeline
from .t2i_adapter import (
@@ -514,7 +514,7 @@ class AllegroPipeline(DiffusionPipeline):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
+8 -7
View File
@@ -246,14 +246,15 @@ def _get_connected_pipeline(pipeline_cls):
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
def get_model(pipeline_class_name):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
def _get_model(pipeline_class_name):
for task_mapping in SUPPORTED_TASKS_MAPPINGS:
for model_name, pipeline in task_mapping.items():
if pipeline.__name__ == pipeline_class_name:
return model_name
model_name = get_model(pipeline_class_name)
def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
model_name = _get_model(pipeline_class_name)
if model_name is not None:
task_class = mapping.get(model_name, None)
@@ -0,0 +1,609 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union, Dict, Any
import copy
import torch
import time
from dataclasses import dataclass
from ..utils import (
is_accelerate_available,
logging,
)
from ..models.modeling_utils import ModelMixin
if is_accelerate_available():
from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
from accelerate.state import PartialState
from accelerate.utils import send_to_device
from accelerate.utils.memory import clear_device_cache
from accelerate.utils.modeling import convert_file_size_to_int
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to
benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch
discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
Arguments:
return_buffers (`bool`, *optional*, defaults to `True`):
Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are
tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm
layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
"""
mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
if return_buffers:
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
mem = mem + mem_bufs
return mem
class CustomOffloadHook(ModelHook):
"""
A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
on the given device. Optionally offloads other models to the CPU before the forward pass is called.
Args:
execution_device(`str`, `int` or `torch.device`, *optional*):
The device on which the model should be executed. Will default to the MPS device if it's available, then
GPU 0 if there is a GPU, and finally to the CPU.
"""
def __init__(
self,
execution_device: Optional[Union[str, int, torch.device]] = None,
other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
self.execution_device = execution_device if execution_device is not None else PartialState().default_device
self.other_hooks = other_hooks
self.offload_strategy = offload_strategy
self.model_id = None
def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
self.offload_strategy = offload_strategy
def add_other_hook(self, hook: "UserCustomOffloadHook"):
"""
Add a hook to the list of hooks to consider for offloading.
"""
if self.other_hooks is None:
self.other_hooks = []
self.other_hooks.append(hook)
def init_hook(self, module):
return module.to("cpu")
def pre_forward(self, module, *args, **kwargs):
if module.device != self.execution_device:
if self.other_hooks is not None:
hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
# offload all other hooks
start_time = time.perf_counter()
if self.offload_strategy is not None:
hooks_to_offload = self.offload_strategy(
hooks=hooks_to_offload,
model_id=self.model_id,
model=module,
execution_device=self.execution_device,
)
end_time = time.perf_counter()
logger.info(
f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
)
for hook in hooks_to_offload:
logger.info(
f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
)
hook.offload()
if hooks_to_offload:
clear_device_cache()
module.to(self.execution_device)
return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
class UserCustomOffloadHook:
"""
A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of
the hook or remove it entirely.
"""
def __init__(self, model_id, model, hook):
self.model_id = model_id
self.model = model
self.hook = hook
def offload(self):
self.hook.init_hook(self.model)
def attach(self):
add_hook_to_module(self.model, self.hook)
self.hook.model_id = self.model_id
def remove(self):
remove_hook_from_module(self.model)
self.hook.model_id = None
def add_other_hook(self, hook: "UserCustomOffloadHook"):
self.hook.add_other_hook(hook)
def custom_offload_with_hook(
model_id: str,
model: torch.nn.Module,
execution_device: Union[str, int, torch.device] = None,
offload_strategy: Optional["AutoOffloadStrategy"] = None,
):
hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
user_hook.attach()
return user_hook
class AutoOffloadStrategy:
"""
Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
the available memory on the device.
"""
def __init__(self, memory_reserve_margin="3GB"):
self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)
def __call__(self, hooks, model_id, model, execution_device):
if len(hooks) == 0:
return []
current_module_size = get_memory_footprint(model)
mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
mem_on_device = mem_on_device - self.memory_reserve_margin
if current_module_size < mem_on_device:
return []
min_memory_offload = current_module_size - mem_on_device
logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")
# exlucde models that's not currently loaded on the device
module_sizes = dict(
sorted(
{hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
key=lambda x: x[1],
reverse=True,
)
)
def search_best_candidate(module_sizes, min_memory_offload):
"""
search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a
minimum memory offload size. the combination of models should add up to the smallest modulesize that is
larger than `min_memory_offload`
"""
model_ids = list(module_sizes.keys())
best_candidate = None
best_size = float("inf")
for r in range(1, len(model_ids) + 1):
for candidate_model_ids in combinations(model_ids, r):
candidate_size = sum(
module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
)
if candidate_size < min_memory_offload:
continue
else:
if best_candidate is None or candidate_size < best_size:
best_candidate = candidate_model_ids
best_size = candidate_size
return best_candidate
best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)
if best_offload_model_ids is None:
# if no combination is found, meaning that we cannot meet the memory requirement, offload all models
logger.warning("no combination of models to offload to cpu is found, offloading all models")
hooks_to_offload = hooks
else:
hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]
return hooks_to_offload
class ComponentsManager:
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict() # Store when components were added
self.model_hooks = None
self._auto_offload_enabled = False
def add(self, name, component):
if name in self.components:
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
self.components[name] = component
self.added_time[name] = time.time()
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
def remove(self, name):
if name not in self.components:
logger.warning(f"Component '{name}' not found in ComponentsManager")
return
self.components.pop(name)
self.added_time.pop(name)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
# YiYi TODO: looking into improving the search pattern
def get(self, names: Union[str, List[str]]):
"""
Get components by name with simple pattern matching.
Args:
names: Component name(s) or pattern(s)
Patterns:
- "unet" : exact match
- "!unet" : everything except exact match "unet"
- "base_*" : everything starting with "base_"
- "!base_*" : everything NOT starting with "base_"
- "*unet*" : anything containing "unet"
- "!*unet*" : anything NOT containing "unet"
- "refiner|vae|unet" : anything containing any of these terms
- "!refiner|vae|unet" : anything NOT containing any of these terms
Returns:
Single component if names is str and matches one component,
dict of components if names matches multiple components or is a list
"""
if isinstance(names, str):
# Check if this is a "not" pattern
is_not_pattern = names.startswith('!')
if is_not_pattern:
names = names[1:] # Remove the ! prefix
# Handle OR patterns (containing |)
if '|' in names:
terms = names.split('|')
matches = {
name: comp for name, comp in self.components.items()
if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}")
else:
logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}")
# Exact match
elif names in self.components:
if is_not_pattern:
matches = {
name: comp for name, comp in self.components.items()
if name != names
}
logger.info(f"Getting all components except '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting component: {names}")
return self.components[names]
# Prefix match (ends with *)
elif names.endswith('*'):
prefix = names[:-1]
matches = {
name: comp for name, comp in self.components.items()
if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
else:
logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
# Contains match (starts with *)
elif names.startswith('*'):
search = names[1:-1] if names.endswith('*') else names[1:]
matches = {
name: comp for name, comp in self.components.items()
if (search in name) != is_not_pattern # Flip condition if not pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
else:
raise ValueError(f"Component '{names}' not found in ComponentsManager")
if not matches:
raise ValueError(f"No components found matching pattern '{names}'")
return matches if len(matches) > 1 else next(iter(matches.values()))
elif isinstance(names, list):
results = {}
for name in names:
result = self.get(name)
if isinstance(result, dict):
results.update(result)
else:
results[name] = result
logger.info(f"Getting multiple components: {list(results.keys())}")
return results
else:
raise ValueError(f"Invalid type for names: {type(names)}")
def enable_auto_cpu_offload(self, device: Union[str, int, torch.device]="cuda", memory_reserve_margin="3GB"):
for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
remove_hook_from_module(component, recurse=True)
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
all_hooks = []
for name, component in self.components.items():
if isinstance(component, torch.nn.Module):
hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
all_hooks.append(hook)
for hook in all_hooks:
other_hooks = [h for h in all_hooks if h is not hook]
for other_hook in other_hooks:
if other_hook.hook.execution_device == hook.hook.execution_device:
hook.add_other_hook(other_hook)
self.model_hooks = all_hooks
self._auto_offload_enabled = True
self._auto_offload_device = device
def disable_auto_cpu_offload(self):
if self.model_hooks is None:
self._auto_offload_enabled = False
return
for hook in self.model_hooks:
hook.offload()
hook.remove()
if self.model_hooks:
clear_device_cache()
self.model_hooks = None
self._auto_offload_enabled = False
def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
"""Get comprehensive information about a component.
Args:
name: Name of the component to get info for
fields: Optional field(s) to return. Can be a string for single field or list of fields.
If None, returns all fields.
Returns:
Dictionary containing requested component metadata.
If fields is specified, returns only those fields.
If a single field is requested as string, returns just that field's value.
"""
if name not in self.components:
raise ValueError(f"Component '{name}' not found in ComponentsManager")
component = self.components[name]
# Build complete info dict first
info = {
"model_id": name,
"added_time": self.added_time[name],
}
# Additional info for torch.nn.Module components
if isinstance(component, torch.nn.Module):
info.update({
"class_name": component.__class__.__name__,
"size_gb": get_memory_footprint(component) / (1024**3),
"adapters": None, # Default to None
})
# Get adapters if applicable
if hasattr(component, "peft_config"):
info["adapters"] = list(component.peft_config.keys())
# Check for IP-Adapter scales
if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
processors = copy.deepcopy(component.attn_processors)
# First check if any processor is an IP-Adapter
processor_types = [v.__class__.__name__ for v in processors.values()]
if any("IPAdapter" in ptype for ptype in processor_types):
# Then get scales only from IP-Adapter processors
scales = {
k: v.scale
for k, v in processors.items()
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
}
if scales:
info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
# If fields specified, filter info
if fields is not None:
if isinstance(fields, str):
# Single field requested, return just that value
return {fields: info.get(fields)}
else:
# List of fields requested, return dict with just those fields
return {k: v for k, v in info.items() if k in fields}
return info
def __repr__(self):
col_widths = {
"id": max(15, max(len(id) for id in self.components.keys())),
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
"device": 10,
"dtype": 15,
"size": 10,
}
# Create the header lines
sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
output = "Components:\n" + sep_line
# Separate components into models and others
models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}
# Models section
if models:
output += "Models:\n" + dash_line
# Column headers
output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
output += dash_line
# Model entries
for name, component in models.items():
info = self.get_model_info(name)
device = str(getattr(component, "device", "N/A"))
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n"
output += dash_line
# Other components section
if others:
if models: # Add extra newline if we had models section
output += "\n"
output += "Other Components:\n" + dash_line
# Column headers for other components
output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n"
output += dash_line
# Other component entries
for name, component in others.items():
output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
output += dash_line
# Add additional component info
output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
for name in self.components:
info = self.get_model_info(name)
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
output += f"\n{name}:\n"
if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"):
output += f" IP-Adapter: Enabled\n"
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
return output
def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
"""
Load components from a pretrained model and add them to the manager.
Args:
pretrained_model_name_or_path (str): The path or identifier of the pretrained model
prefix (str, optional): Prefix to add to all component names loaded from this model.
If provided, components will be named as "{prefix}_{component_name}"
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
"""
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
if component_name not in self.components:
self.add(component_name, component)
else:
logger.warning(
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
"""Summarizes a dictionary by finding common prefixes that share the same value.
For a dictionary with dot-separated keys like:
{
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
}
Returns a dictionary where keys are the shortest common prefixes and values are their shared values:
{
'down_blocks': [0.6],
'up_blocks': [0.3]
}
"""
# First group by values - convert lists to tuples to make them hashable
value_to_keys = {}
for key, value in d.items():
value_tuple = tuple(value) if isinstance(value, list) else value
if value_tuple not in value_to_keys:
value_to_keys[value_tuple] = []
value_to_keys[value_tuple].append(key)
def find_common_prefix(keys: List[str]) -> str:
"""Find the shortest common prefix among a list of dot-separated keys."""
if not keys:
return ""
if len(keys) == 1:
return keys[0]
# Split all keys into parts
key_parts = [k.split('.') for k in keys]
# Find how many initial parts are common
common_length = 0
for parts in zip(*key_parts):
if len(set(parts)) == 1: # All parts at this position are the same
common_length += 1
else:
break
if common_length == 0:
return ""
# Return the common prefix
return '.'.join(key_parts[0][:common_length])
# Create summary by finding common prefixes for each value group
summary = {}
for value_tuple, keys in value_to_keys.items():
prefix = find_common_prefix(keys)
if prefix: # Only add if we found a common prefix
# Convert tuple back to list if it was originally a list
value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
summary[prefix] = value
else:
summary[""] = value # Use empty string if no common prefix
return summary
@@ -912,12 +912,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -931,6 +925,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -867,12 +867,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -886,6 +880,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -484,7 +484,7 @@ class IFPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -528,7 +528,7 @@ class IFImg2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -281,7 +281,7 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoa
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -568,7 +568,7 @@ class IFInpaintingPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -283,7 +283,7 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLora
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -239,7 +239,7 @@ class IFSuperResolutionPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -574,7 +574,7 @@ class StableDiffusionModelEditingPipeline(
idxs_replace.append(76)
idxs_replaces.append(idxs_replace)
# prepare batch: for each pair of setences, old context and new values
# prepare batch: for each pair of sentences, old context and new values
contexts, valuess = [], []
for old_emb, new_emb, idxs_replace in zip(old_embs, new_embs, idxs_replaces):
context = old_emb.detach()
+2 -10
View File
@@ -490,14 +490,6 @@ class FluxPipeline(
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
@@ -821,7 +813,7 @@ class FluxPipeline(
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
_,
negative_text_ids,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
@@ -938,7 +930,7 @@ class FluxPipeline(
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=text_ids,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
@@ -800,17 +800,20 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
height, width = control_image.shape[-2:]
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
if self.controlnet.input_hint_block is None:
control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
height_control_image, width_control_image = control_image.shape[2:]
control_image = self._pack_latents(
control_image,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
if control_mode is not None:
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
@@ -819,7 +822,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
for control_image_ in control_image:
# xlab controlnet has a input_hint_block and instantx controlnet does not
controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
for i, control_image_ in enumerate(control_image):
control_image_ = self.prepare_image(
image=control_image_,
width=width,
@@ -831,17 +836,18 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
)
height, width = control_image_.shape[-2:]
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
if self.controlnet.nets[0].input_hint_block is None:
control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)
@@ -955,6 +961,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
controlnet_blocks_repeat=controlnet_blocks_repeat,
)[0]
latents_dtype = latents.dtype
@@ -13,6 +13,7 @@ from transformers import (
)
from ...image_processor import VaeImageProcessor
from ...loaders import HiDreamImageLoraLoaderMixin
from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
@@ -142,7 +143,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class HiDreamImagePipeline(DiffusionPipeline):
class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]
@@ -822,13 +823,13 @@ class HiDreamImagePipeline(DiffusionPipeline):
if prompt_embeds is not None:
deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead."
deprecate("prompt_embeds", "0.34.0", deprecation_message)
deprecate("prompt_embeds", "0.35.0", deprecation_message)
prompt_embeds_t5 = prompt_embeds[0]
prompt_embeds_llama3 = prompt_embeds[1]
if negative_prompt_embeds is not None:
deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead."
deprecate("negative_prompt_embeds", "0.34.0", deprecation_message)
deprecate("negative_prompt_embeds", "0.35.0", deprecation_message)
negative_prompt_embeds_t5 = negative_prompt_embeds[0]
negative_prompt_embeds_llama3 = negative_prompt_embeds[1]
@@ -14,14 +14,13 @@
from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from PIL import Image
from transformers import (
XLMRobertaTokenizer,
)
from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDIMScheduler
from ...utils import (
@@ -95,15 +94,6 @@ def get_new_h_w(h, w, scale_factor=8):
return new_h * scale_factor, new_w * scale_factor
def prepare_image(pil_image, w=512, h=512):
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image
class KandinskyImg2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -143,7 +133,16 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
self.movq_scale_factor = (
2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
)
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.movq_scale_factor,
vae_latent_channels=movq_latent_channels,
resample="bicubic",
reducing_gap=1,
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
@@ -417,7 +416,7 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=prompt_embeds.dtype, device=device)
latents = self.movq.encode(image)["latents"]
@@ -498,13 +497,7 @@ class KandinskyImg2ImgPipeline(DiffusionPipeline):
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type)
if not return_dict:
return (image,)
@@ -14,11 +14,10 @@
from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from PIL import Image
from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
@@ -105,27 +104,6 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image
class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -157,7 +135,14 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
self.image_processor = VaeImageProcessor(
vae_scale_factor=movq_scale_factor,
vae_latent_channels=movq_latent_channels,
resample="bicubic",
reducing_gap=1,
)
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
@@ -316,7 +301,7 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=image_embeds.dtype, device=device)
latents = self.movq.encode(image)["latents"]
@@ -324,7 +309,6 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
@@ -379,13 +363,7 @@ class KandinskyV22ControlnetImg2ImgPipeline(DiffusionPipeline):
if output_type not in ["pt", "np", "pil"]:
raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}")
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type)
if not return_dict:
return (image,)
@@ -14,11 +14,10 @@
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
from PIL import Image
from ...image_processor import VaeImageProcessor
from ...models import UNet2DConditionModel, VQModel
from ...schedulers import DDPMScheduler
from ...utils import deprecate, is_torch_xla_available, logging
@@ -76,27 +75,6 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.kandinsky2_2.pipeline_kandinsky2_2.downscale_height_and_width
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.prepare_image
def prepare_image(pil_image, w=512, h=512):
pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image
class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
"""
Pipeline for image-to-image generation using Kandinsky
@@ -129,7 +107,14 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
scheduler=scheduler,
movq=movq,
)
self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1)
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
self.image_processor = VaeImageProcessor(
vae_scale_factor=movq_scale_factor,
vae_latent_channels=movq_latent_channels,
resample="bicubic",
reducing_gap=1,
)
# Copied from diffusers.pipelines.kandinsky.pipeline_kandinsky_img2img.KandinskyImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
@@ -319,7 +304,7 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
image = torch.cat([prepare_image(i, width, height) for i in image], dim=0)
image = torch.cat([self.image_processor.preprocess(i, width, height) for i in image], dim=0)
image = image.to(dtype=image_embeds.dtype, device=device)
latents = self.movq.encode(image)["latents"]
@@ -327,7 +312,6 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
height, width = downscale_height_and_width(height, width, self.movq_scale_factor)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, image_embeds.dtype, device, generator
)
@@ -383,21 +367,9 @@ class KandinskyV22Img2ImgPipeline(DiffusionPipeline):
if XLA_AVAILABLE:
xm.mark_step()
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil` ,`np` and `latent` are supported not output_type={output_type}"
)
if not output_type == "latent":
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type)
else:
image = latents
@@ -1,12 +1,12 @@
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL
import PIL.Image
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
@@ -53,24 +53,6 @@ EXAMPLE_DOC_STRING = """
"""
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
def prepare_image(pil_image):
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image
class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->movq->unet->movq"
_callback_tensor_inputs = [
@@ -94,6 +76,14 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
)
movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) if getattr(self, "movq", None) else 8
movq_latent_channels = self.movq.config.latent_channels if getattr(self, "movq", None) else 4
self.image_processor = VaeImageProcessor(
vae_scale_factor=movq_scale_factor,
vae_latent_channels=movq_latent_channels,
resample="bicubic",
reducing_gap=1,
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
@@ -566,7 +556,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
image = torch.cat([prepare_image(i) for i in image], dim=0)
image = torch.cat([self.image_processor.preprocess(i) for i in image], dim=0)
image = image.to(dtype=prompt_embeds.dtype, device=device)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -630,20 +620,9 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
xm.mark_step()
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
)
if not output_type == "latent":
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
image = self.image_processor.postprocess(image, output_type)
else:
image = latents
@@ -609,12 +609,6 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -628,6 +622,11 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -501,7 +501,7 @@ class LattePipeline(DiffusionPipeline):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -534,7 +534,7 @@ class LuminaPipeline(DiffusionPipeline):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
File diff suppressed because it is too large Load Diff
+15 -2
View File
@@ -75,6 +75,11 @@ class OnnxRuntimeModel:
logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
provider = "CPUExecutionProvider"
if provider_options is None:
provider_options = []
elif not isinstance(provider_options, list):
provider_options = [provider_options]
return ort.InferenceSession(
path, providers=[provider], sess_options=sess_options, provider_options=provider_options
)
@@ -174,7 +179,10 @@ class OnnxRuntimeModel:
# load model from local directory
if os.path.isdir(model_id):
model = OnnxRuntimeModel.load_model(
Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
Path(model_id, model_file_name).as_posix(),
provider=provider,
sess_options=sess_options,
provider_options=kwargs.pop("provider_options"),
)
kwargs["model_save_dir"] = Path(model_id)
# load model from hub
@@ -190,7 +198,12 @@ class OnnxRuntimeModel:
)
kwargs["model_save_dir"] = Path(model_cache_path).parent
kwargs["latest_model_name"] = Path(model_cache_path).name
model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
model = OnnxRuntimeModel.load_model(
model_cache_path,
provider=provider,
sess_options=sess_options,
provider_options=kwargs.pop("provider_options"),
)
return cls(model=model, **kwargs)
@classmethod
@@ -917,12 +917,6 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -936,6 +930,11 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -488,7 +488,7 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -524,7 +524,7 @@ class SanaPAGPipeline(DiffusionPipeline, PAGMixin):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -707,12 +707,6 @@ class StableDiffusionXLPAGImg2ImgPipeline(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# Offload text encoder if `enable_model_cpu_offload` was enabled
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.text_encoder_2.to("cpu")
@@ -726,6 +720,11 @@ class StableDiffusionXLPAGImg2ImgPipeline(
init_latents = image
else:
latents_mean = latents_std = None
if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
# make sure the VAE is in float32 mode, as it overflows in float16
if self.vae.config.force_upcast:
image = image.float()
@@ -469,7 +469,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
class_obj = import_flax_or_no_model(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
@@ -341,13 +341,13 @@ def get_class_obj_and_candidates(
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
elif os.path.isfile(os.path.join(component_folder, library_name + ".py")):
# load custom component
class_obj = get_class_from_dynamic_module(
component_folder, module_file=library_name + ".py", class_name=class_name
)
class_candidates = {c: class_obj for c in importable_classes.keys()}
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
@@ -412,7 +412,7 @@ def _get_pipeline_class(
revision=revision,
)
if class_obj.__name__ != "DiffusionPipeline":
if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
return class_obj
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
+21 -3
View File
@@ -58,6 +58,7 @@ from ..utils import (
_is_valid_type,
is_accelerate_available,
is_accelerate_version,
is_hpu_available,
is_torch_npu_available,
is_torch_version,
is_transformers_version,
@@ -426,7 +427,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -443,6 +444,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
@@ -450,6 +452,20 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
# Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
os.environ["PT_HPU_GPU_MIGRATION"] = "1"
logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
import habana_frameworks.torch # noqa: F401
# HPU hardware check
if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
module_names, _ = self._get_signature_keys(self)
modules = [getattr(self, n, None) for n in module_names]
modules = [m for m in modules if isinstance(m, torch.nn.Module)]
@@ -1104,9 +1120,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
automatically detect the available accelerator and use.
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1230,7 +1248,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
self.remove_all_hooks()
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -598,7 +598,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -525,7 +525,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -600,7 +600,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -615,7 +615,7 @@ class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -491,7 +491,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
# ip addresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
@@ -175,7 +175,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
processor ([`~transformers.CLIPProcessor`]):
A `CLIPProcessor` to procces reference image.
A `CLIPProcessor` to process reference image.
image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
Frozen image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
image_project ([`CLIPImageProjection`]):

Some files were not shown because too many files have changed in this diff Show More