Compare commits

..

61 Commits

Author SHA1 Message Date
Sayak Paul 51a855c8c6 Merge branch 'main' into layerwise-upcasting 2024-08-19 14:22:12 +05:30
Dhruv Nair 940b8e0358 [CI] Multiple Slow Test fixes. (#9198)
* update

* update

* update

* update
2024-08-19 13:31:09 +05:30
Dhruv Nair b2add10d13 Update is_safetensors_compatible check (#8991)
* update

* update

* update

* update

* update
2024-08-19 11:35:22 +05:30
Wenlong Wu 815d882217 Add loading text inversion (#9130) 2024-08-19 09:26:27 +05:30
Sayak Paul c64fa22c08 Merge branch 'main' into layerwise-upcasting 2024-08-19 09:13:18 +05:30
M Saqlain ba4348d9a7 [Tests] Improve transformers model test suite coverage - Lumina (#8987)
* Added test suite for lumina

* Fixed failing tests

* Improved code quality

* Added function docstrings

* Improved formatting
2024-08-19 08:29:03 +05:30
townwish4git d25eb5d385 fix(sd3): fix deletion of text_encoders etc (#8951)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-08-18 15:37:40 -10:00
Tolga Cangöz 7ef8a46523 [Docs] Fix CPU offloading usage (#9207)
* chore: Fix cpu offloading usage

* Trim trailing white space

* docs: update Kolors model link in kolors.md
2024-08-18 13:12:12 -10:00
Sayak Paul f848febacd feat: allow sharding for auraflow. (#8853) 2024-08-18 08:47:26 +05:30
Beinsezii b38255006a Add Lumina T2I Auto Pipe Mapping (#8962) 2024-08-16 23:14:17 -10:00
Jianqi Pan cba548d8a3 fix(pipeline): k sampler sigmas device (#9189)
If Karras is not enabled, a device inconsistency error will occur. This is due to the fact that sigmas were not moved to the specified device.
2024-08-16 22:43:42 -10:00
Álvaro Somoza db829a4be4 [IP Adapter] Fix object has no attribute with image encoder (#9194)
* fix

* apply suggestion
2024-08-17 02:00:04 -04:00
Sayak Paul 0d1a1f875a Merge branch 'main' into layerwise-upcasting 2024-08-16 14:15:15 +05:30
Sayak Paul e780c05cc3 [Chore] add set_default_attn_processor to pixart. (#9196)
add set_default_attn_processor to pixart.
2024-08-16 13:07:06 +05:30
C e649678bf5 [Flux] Optimize guidance creation in flux pipeline by moving it outside the loop (#9153)
* optimize guidance creation in flux pipeline by moving it outside the loop

* use torch.full instead of torch.tensor to create a tensor with a single value

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-16 10:14:05 +05:30
Sayak Paul 39b87b14b5 feat: allow flux transformer to be sharded during inference (#9159)
* feat: support sharding for flux.

* tests
2024-08-16 10:00:51 +05:30
Sayak Paul f1fa1235e4 Merge branch 'main' into layerwise-upcasting 2024-08-16 09:48:53 +05:30
Dhruv Nair 3e46043223 Small improvements for video loading (#9183)
* update

* update
2024-08-16 09:36:58 +05:30
Simo Ryu 1a92bc05a7 Add Learned PE selection for Auraflow (#9182)
* add pe

* Update src/diffusers/models/transformers/auraflow_transformer_2d.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/models/transformers/auraflow_transformer_2d.py

* beauty

* retrigger ci.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-15 16:30:24 +05:30
Sayak Paul 9b411e5ff3 Merge branch 'main' into layerwise-upcasting 2024-08-15 10:34:40 +05:30
Dhruv Nair b366b22191 update 2024-08-14 14:50:18 +02:00
Dhruv Nair 1fdae85f49 update 2024-08-14 14:19:20 +02:00
dependabot[bot] 0c1e63bd11 Bump jinja2 from 3.1.3 to 3.1.4 in /examples/research_projects/realfill (#7873)
Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.3 to 3.1.4.
- [Release notes](https://github.com/pallets/jinja/releases)
- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst)
- [Commits](https://github.com/pallets/jinja/compare/3.1.3...3.1.4)

---
updated-dependencies:
- dependency-name: jinja2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-14 16:36:59 +05:30
dependabot[bot] e7e45bd127 Bump torch from 2.0.1 to 2.2.0 in /examples/research_projects/realfill (#8971)
Bumps [torch](https://github.com/pytorch/pytorch) from 2.0.1 to 2.2.0.
- [Release notes](https://github.com/pytorch/pytorch/releases)
- [Changelog](https://github.com/pytorch/pytorch/blob/main/RELEASE.md)
- [Commits](https://github.com/pytorch/pytorch/compare/v2.0.1...v2.2.0)

---
updated-dependencies:
- dependency-name: torch
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-14 16:36:46 +05:30
Álvaro Somoza 82058a5413 post release 0.30.0 (#9173)
* post release

* fix quality
2024-08-14 12:55:55 +05:30
Dhruv Nair 6b9fd0905e update 2024-08-14 08:22:43 +02:00
Aryan a85b34e7fd [refactor] CogVideoX followups + tiled decoding support (#9150)
* refactor context parallel cache; update torch compile time benchmark

* add tiling support

* make style

* remove num_frames % 8 == 0 requirement

* update default num_frames to original value

* add explanations + refactor

* update torch compile example

* update docs

* update

* clean up if-statements

* address review comments

* add test for vae tiling

* update docs

* update docs

* update docstrings

* add modeling test for cogvideox transformer

* make style
2024-08-14 03:53:21 +05:30
王奇勋 5ffbe14c32 [FLUX] Support ControlNet (#9126)
* cnt model

* cnt model

* cnt model

* fix Loader "Copied"

* format

* txt_ids for  multiple images

* add test and format

* typo

* Update pipeline_flux_controlnet.py

* remove

* make quality

* fix copy

* Update src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* Update src/diffusers/models/controlnet_flux.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fix

* make copies

* test

* bs

---------

Co-authored-by: haofanwang <haofanwang.ai@gmail.com>
Co-authored-by: haofanwang <haofan@HaofandeMBP.lan>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-08-13 18:17:40 +05:30
Dhruv Nair be55fa631f update 2024-08-13 14:11:47 +02:00
林金鹏 cc0513091a Support SD3 controlnet inpainting (#9099)
* add controlnet inpainting pipeline

* [SD3] add controlnet inpaint example

* update example and fix code style

* fix code style with ruff

* Update controlnet_sd3.md : add control inpaint pipeline

* Update docs/source/en/api/pipelines/controlnet_sd3.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update docs/source/en/api/pipelines/controlnet_sd3.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update docs/source/en/api/pipelines/controlnet_sd3.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Update __init__.py : add sd3 control pipelines

* Update pipeline : add new param doc & check input reference.

* fix typo

* make style & make quality

* add unittest for sd3 controlnet inpaint

---------

Co-authored-by: 鹏徙 <linjinpeng.ljp@alibaba-inc.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
2024-08-13 17:30:46 +05:30
Sayak Paul 15eb77bc4c Update distributed_inference.md to include a fuller example on distributed inference (#9152)
* Update distributed_inference.md

* Update docs/source/en/training/distributed_inference.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-08-12 09:56:03 -07:00
Linoy Tsaban 413ca29b71 [Flux Dreambooth LoRA] - te bug fixes & updates (#9139)
* add requirements + fix link to bghira's guide

* text ecnoder training fixes

* text encoder training fixes

* text encoder training fixes

* text encoder training fixes

* style

* add tests

* fix encode_prompt call

* style

* unpack_latents test

* fix lora saving

* remove default val for max_sequenece_length in encode_prompt

* remove default val for max_sequenece_length in encode_prompt

* style

* testing

* style

* testing

* testing

* style

* fix sizing issue

* style

* revert scaling

* style

* style

* scaling test

* style

* scaling test

* remove model pred operation left from pre-conditioning

* remove model pred operation left from pre-conditioning

* fix trainable params

* remove te2 from casting

* transformer to accelerator

* remove prints

* empty commit
2024-08-12 11:58:03 +05:30
Dhruv Nair 10dc06c8d9 Update Video Loading/Export to use imageio (#9094)
* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-12 10:19:53 +05:30
Dibbla! 3ece143308 Errata - fix typo (#9100) 2024-08-12 07:30:19 +05:30
Steven Liu 98930ee131 [docs] Resolve internal links to PEFT (#9144)
* resolve peft links

* fuse_lora
2024-08-10 06:37:46 +05:30
Daniel Socek c1079f0887 Fix textual inversion SDXL and add support for 2nd text encoder (#9010)
* Fix textual inversion SDXL and add support for 2nd text encoder

Signed-off-by: Daniel Socek <daniel.socek@intel.com>

* Fix style/quality of text inv for sdxl

Signed-off-by: Daniel Socek <daniel.socek@intel.com>

---------

Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-09 20:23:06 +05:30
Linoy Tsaban 65e30907b5 [Flux] Dreambooth LoRA training scripts (#9086)
* initial commit - dreambooth for flux

* update transformer to be FluxTransformer2DModel

* update training loop and validation inference

* fix sd3->flux docs

* add guidance handling, not sure if it makes sense(?)

* inital dreambooth lora commit

* fix text_ids in compute_text_embeddings

* fix imports of static methods

* fix pipeline loading in readme, remove auto1111 docs for now

* fix pipeline loading in readme, remove auto1111 docs for now, remove some irrelevant text_encoder_3 refs

* Update examples/dreambooth/train_dreambooth_flux.py

Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>

* fix te2 loading and remove te2 refs from text encoder training

* fix tokenizer_2 initialization

* remove text_encoder training refs from lora script (for now)

* try with vae in bfloat16, fix model hook save

* fix tokenization

* fix static imports

* fix CLIP import

* remove text_encoder training refs (for now) from lora script

* fix minor bug in encode_prompt, add guidance def in lora script, ...

* fix unpack_latents args

* fix license in readme

* add "none" to weighting_scheme options for uniform sampling

* style

* adapt model saving - remove text encoder refs

* adapt model loading - remove text encoder refs

* initial commit for readme

* Update examples/dreambooth/train_dreambooth_lora_flux.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update examples/dreambooth/train_dreambooth_lora_flux.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* fix vae casting

* remove precondition_outputs

* readme

* readme

* style

* readme

* readme

* update weighting scheme default & docs

* style

* add text_encoder training to lora script, change vae_scale_factor value in both

* style

* text encoder training fixes

* style

* update readme

* minor fixes

* fix te params

* fix te params

---------

Co-authored-by: Bagheera <59658056+bghira@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-09 07:31:04 +05:30
Sayak Paul cee7c1b0fb Update README.md to include InstantID (#8770)
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2024-08-08 10:14:12 -07:00
Monjoy Narayan Choudhury 1fcb811a8e Add Differential Diffusion to HunyuanDiT. (#9040)
* Add Differential Pipeline.

* Fix Styling Issue using ruff -fix

* Add details to Contributing.md

* Revert "Fix Styling Issue using ruff -fix"

This reverts commit d347de162d.

* Revert "Revert "Fix Styling Issue using ruff -fix""

This reverts commit ce7c3ff216.

* Revert README changes

* Restore README.md

* Update README.md

* Resolved Comments:

* Fix Readme based on review

* Fix formatting after make style

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-08-08 18:53:39 +05:30
David Steinberg ae026db7aa Fix a dead link (#9116)
Co-authored-by: Aryan <aryan@huggingface.co>
2024-08-08 18:46:50 +05:30
sayantan sadhu 8e3affc669 fix for lr scheduler in distributed training (#9103)
* fix for lr scheduler in distributed training

* Fixed the recalculation of the total training step section

* Fixed lint error

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-08 08:45:48 +05:30
Steven Liu ba7e48455a [docs] Organize model toctree (#9118)
* toctree

* fix
2024-08-08 08:31:58 +05:30
zR 2dad462d9b Add CogVideoX text-to-video generation model (#9082)
* add CogVideoX

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-08-06 21:23:57 -10:00
Dhruv Nair e3568d14ba Freenoise change vae_batch_size to decode_chunk_size (#9110)
* update

* update
2024-08-07 12:47:18 +05:30
Aryan f6df22447c [feat] allow sparsectrl to be loaded from single file (#9073)
* allow sparsectrl to be loaded with single file

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-08-07 11:12:30 +05:30
latentCall145 9b5180cb5f Flux fp16 inference fix (#9097)
* clipping for fp16

* fix typo

* added fp16 inference to docs

* fix docs typo

* include link for fp16 investigation

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-07 10:54:20 +05:30
Aryan 16a93f1a25 [core] FreeNoise (#8948)
* initial work draft for freenoise; needs massive cleanup

* fix freeinit bug

* add animatediff controlnet implementation

* revert attention changes

* add freenoise

* remove old helper functions

* add decode batch size param to all pipelines

* make style

* fix copied from comments

* make fix-copies

* make style

* copy animatediff controlnet implementation from #8972

* add experimental support for num_frames not perfectly fitting context length, ocntext stride

* make unet motion model lora work again based on #8995

* copy load video utils from #8972

* copied from AnimateDiff::prepare_latents

* address the case where last batch of frames does not match length of indices in prepare latents

* decode_batch_size->vae_batch_size; batch vae encode support in animatediff vid2vid

* revert sparsectrl and sdxl freenoise changes

* revert pia

* add freenoise tests

* make fix-copies

* improve docstrings

* add freenoise tests to animatediff controlnet

* update tests

* Update src/diffusers/models/unets/unet_motion_model.py

* add freenoise to animatediff pag

* address review comments

* make style

* update tests

* make fix-copies

* fix error message

* remove copied from comment

* fix imports in tests

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-08-07 10:35:18 +05:30
Sayak Paul 2d753b6fb5 fix train_dreambooth_lora_sd3.py loading hook (#9107) 2024-08-07 10:09:47 +05:30
Álvaro Somoza 39e1f7eaa4 [Kolors] Add PAG (#8934)
* txt2img pag added

* autopipe added, fixed case

* style

* apply suggestions

* added fast tests, added todo tests

* revert dummy objects for kolors

* fix pag dummies

* fix test imports

* update pag tests

* add kolor pag to docs

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-07 09:29:52 +05:30
Dhruv Nair e1b603dc2e [Single File] Add single file support for Flux Transformer (#9083)
* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-07 08:49:57 +05:30
Marc Sun e4325606db Fix loading sharded checkpoints when we have variants (#9061)
* Fix loading sharded checkpoint when we have variant

* add test

* remote print

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-06 13:38:44 -10:00
Ahn Donghoon (안동훈 / suno) 926daa30f9 add PAG support for Stable Diffusion 3 (#8861)
add pag sd3


---------

Co-authored-by: HyoungwonCho <jhw9811@korea.ac.kr>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: crepejung00 <jaewoojung00@naver.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2024-08-06 09:11:35 -10:00
Dhruv Nair 325a5de3a9 [Docs] Add community projects section to docs (#9013)
* update

* update

* update
2024-08-06 08:59:39 -07:00
Dhruv Nair 4c6152c2fb update 2024-08-06 12:00:14 +00:00
Vinh H. Pham 87e50a2f1d [Tests] Improve transformers model test suite coverage - Hunyuan DiT (#8916)
* add hunyuan model test

* apply suggestions

* reduce dims further

* reduce dims further

* run make style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-06 12:59:30 +05:30
Aryan a57a7af45c [bug] remove unreachable norm_type=ada_norm_continuous from norm3 initialization conditions (#9006)
remove ada_norm_continuous from norm3 list

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-06 07:23:48 +05:30
Sayak Paul 52f1378e64 [Core] add QKV fusion to AuraFlow and PixArt Sigma (#8952)
* add fusion support to pixart

* add to auraflow.

* add tests

* apply review feedback.

* add back args and kwargs

* style
2024-08-05 14:09:37 -10:00
Tolga Cangöz 3dc97bd148 Update CLIPFeatureExtractor to CLIPImageProcessor and DPTFeatureExtractor to DPTImageProcessor (#9002)
* fix: update `CLIPFeatureExtractor` to `CLIPImageProcessor` in codebase

* `make style && make quality`

* Update `DPTFeatureExtractor` to `DPTImageProcessor` in codebase

* `make style`

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2024-08-05 09:20:29 -10:00
omahs 6d32b29239 Fix typos (#9077)
* fix typo
2024-08-05 09:00:08 -10:00
YiYi Xu bc3c73ad0b add sentencepiece as a soft dependency (#9065)
* add sentencepiece as  soft dependency for kolors

* up

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-05 08:04:51 -10:00
Sayak Paul 5934873b8f [Docs] add stable cascade unet doc. (#9066)
* add stable cascade unet doc.

* fix path
2024-08-05 21:28:48 +05:30
192 changed files with 15112 additions and 1027 deletions
+1
View File
@@ -202,6 +202,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
- https://github.com/microsoft/TaskMatrix
- https://github.com/invoke-ai/InvokeAI
- https://github.com/InstantID/InstantID
- https://github.com/apple/ml-stable-diffusion
- https://github.com/Sanster/lama-cleaner
- https://github.com/IDEA-Research/Grounded-Segment-Anything
+76 -56
View File
@@ -190,6 +190,10 @@
- local: conceptual/evaluation
title: Evaluating Diffusion Models
title: Conceptual Guides
- sections:
- local: community_projects
title: Projects built with Diffusers
title: Community Projects
- sections:
- isExpanded: false
sections:
@@ -219,62 +223,76 @@
sections:
- local: api/models/overview
title: Overview
- local: api/models/unet
title: UNet1DModel
- local: api/models/unet2d
title: UNet2DModel
- local: api/models/unet2d-cond
title: UNet2DConditionModel
- local: api/models/unet3d-cond
title: UNet3DConditionModel
- local: api/models/unet-motion
title: UNetMotionModel
- local: api/models/uvit2d
title: UViT2DModel
- local: api/models/vq
title: VQModel
- local: api/models/autoencoderkl
title: AutoencoderKL
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_tiny
title: Tiny AutoEncoder
- local: api/models/autoencoder_oobleck
title: Oobleck AutoEncoder
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/transformer2d
title: Transformer2DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/prior_transformer
title: PriorTransformer
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_hunyuandit
title: HunyuanDiT2DControlNetModel
- local: api/models/controlnet_sd3
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
- sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_hunyuandit
title: HunyuanDiT2DControlNetModel
- local: api/models/controlnet_sd3
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
title: ControlNets
- sections:
- local: api/models/aura_flow_transformer2d
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hunyuan_transformer2d
title: HunyuanDiT2DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
title: PriorTransformer
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/transformer2d
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
title: StableCascadeUNet
- local: api/models/unet
title: UNet1DModel
- local: api/models/unet2d
title: UNet2DModel
- local: api/models/unet2d-cond
title: UNet2DConditionModel
- local: api/models/unet3d-cond
title: UNet3DConditionModel
- local: api/models/unet-motion
title: UNetMotionModel
- local: api/models/uvit2d
title: UViT2DModel
title: UNets
- sections:
- local: api/models/autoencoderkl
title: AutoencoderKL
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
title: Oobleck AutoEncoder
- local: api/models/autoencoder_tiny
title: Tiny AutoEncoder
- local: api/models/vq
title: VQModel
title: VAEs
title: Models
- isExpanded: false
sections:
@@ -296,6 +314,8 @@
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP-Diffusion
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
@@ -53,6 +53,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
- [`AutoencoderKLCogVideoX`]
- [`ControlNetModel`]
- [`SD3Transformer2DModel`]
- [`FluxTransformer2DModel`]
## FromSingleFileMixin
@@ -11,18 +11,14 @@ specific language governing permissions and limitations under the License. -->
# AutoencoderKLCogVideoX
The 3D variational autoencoder (VAE) model with KL loss using CogVideoX.
The 3D variational autoencoder (VAE) model with KL loss used in [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.
## Loading from the original format
The model can be loaded with the following code snippet.
By default, the [`AutoencoderKLCogVideoX`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
```py
```python
from diffusers import AutoencoderKLCogVideoX
url = "THUDM/CogVideoX-2b" # can also be a local file
model = AutoencoderKLCogVideoX.from_single_file(url)
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16).to("cuda")
```
## AutoencoderKLCogVideoX
@@ -32,38 +28,10 @@ model = AutoencoderKLCogVideoX.from_single_file(url)
- encode
- all
## CogVideoXSafeConv3d
## AutoencoderKLOutput
[[autodoc]] CogVideoXSafeConv3d
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## CogVideoXCausalConv3d
## DecoderOutput
[[autodoc]] CogVideoXCausalConv3d
## CogVideoXSpatialNorm3D
[[autodoc]] CogVideoXSpatialNorm3D
## CogVideoXResnetBlock3D
[[autodoc]] CogVideoXResnetBlock3D
## CogVideoXDownBlock3D
[[autodoc]] CogVideoXDownBlock3D
## CogVideoXMidBlock3D
[[autodoc]] CogVideoXMidBlock3D
## CogVideoXUpBlock3D
[[autodoc]] CogVideoXUpBlock3D
## CogVideoXEncoder3D
[[autodoc]] CogVideoXEncoder3D
## CogVideoXDecoder3D
[[autodoc]] CogVideoXDecoder3D
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -9,10 +9,22 @@ Unless required by applicable law or agreed to in writing, software distributed
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
## CogVideoXTransformer3DModel
# CogVideoXTransformer3DModel
A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideoX).
A Diffusion Transformer model for 3D data from [CogVideoX](https://github.com/THUDM/CogVideo) was introduced in [CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) by Tsinghua University & ZhipuAI.
The model can be loaded with the following code snippet.
```python
from diffusers import CogVideoXTransformer3DModel
vae = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
## CogVideoXTransformer3DModel
[[autodoc]] CogVideoXTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,19 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# StableCascadeUNet
A UNet model from the [Stable Cascade pipeline](../pipelines/stable_cascade.md).
## StableCascadeUNet
[[autodoc]] models.unets.unet_stable_cascade.StableCascadeUNet
+30 -21
View File
@@ -10,18 +10,16 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
## TODO: The paper is still being written.
# limitations under the License.
-->
# CogVideoX
[TODO]() from Tsinghua University & ZhipuAI.
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
The abstract from the paper is:
The paper is still being written.
*We introduce CogVideoX, a large-scale diffusion transformer model designed for generating videos based on text prompts. To efficently model video data, we propose to levearge a 3D Variational Autoencoder (VAE) to compresses videos along both spatial and temporal dimensions. To improve the text-video alignment, we propose an expert transformer with the expert adaptive LayerNorm to facilitate the deep fusion between the two modalities. By employing a progressive training technique, CogVideoX is adept at producing coherent, long-duration videos characterized by significant motion. In addition, we develop an effectively text-video data processing pipeline that includes various data preprocessing strategies and a video captioning method. It significantly helps enhance the performance of CogVideoX, improving both generation quality and semantic alignment. Results show that CogVideoX demonstrates state-of-the-art performance across both multiple machine metrics and human evaluations. The model weight of CogVideoX-2B is publicly available at https://github.com/THUDM/CogVideo.*
<Tip>
@@ -29,7 +27,9 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
</Tip>
### Inference
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
## Inference
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
@@ -37,38 +37,46 @@ First, load the pipeline:
```python
import torch
from diffusers import LattePipeline
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
pipeline = LattePipeline.from_pretrained(
"THUDM/CogVideoX-2b", torch_dtype=torch.float16
).to("cuda")
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
```
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
```python
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
pipe.transformer.to(memory_format=torch.channels_last)
```
Finally, compile the components and run inference:
```python
pipeline.transformer = torch.compile(pipeline.transformer)
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
# CogVideoX works very well with long and well-described prompts
# CogVideoX works well with long and well-described prompts
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
```
The [benchmark](TODO: link) results on an 80GB A100 machine are:
The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
```
Without torch.compile(): Average inference time: TODO seconds.
With torch.compile(): Average inference time: TODO seconds.
Without torch.compile(): Average inference time: 96.89 seconds.
With torch.compile(): Average inference time: 76.27 seconds.
```
### Memory optimization
CogVideoX requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
- `pipe.enable_model_cpu_offload()`:
- Without enabling cpu offloading, memory usage is `33 GB`
- With enabling cpu offloading, memory usage is `19 GB`
- `pipe.vae.enable_tiling()`:
- With enabling cpu offloading and tiling, memory usage is `11 GB`
- `pipe.vae.enable_slicing()`
## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline
@@ -76,4 +84,5 @@ With torch.compile(): Average inference time: TODO seconds.
- __call__
## CogVideoXPipelineOutput
[[autodoc]] pipelines.pipline_cogvideo.pipeline_output.CogVideoXPipelineOutput
[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput
+16 -2
View File
@@ -1,4 +1,4 @@
<!--Copyright 2023 The HuggingFace Team and The InstantX Team. All rights reserved.
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
@@ -22,7 +22,16 @@ The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for SD3-ControlNet on [The InstantX Team](https://huggingface.co/InstantX) Hub profile.
This controlnet code is mainly implemented by [The InstantX Team](https://huggingface.co/InstantX). The inpainting-related code was developed by [The Alimama Creative Team](https://huggingface.co/alimama-creative). You can find pre-trained checkpoints for SD3-ControlNet in the table below:
| ControlNet type | Developer | Link |
| -------- | ---------- | ---- |
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Canny) |
| Pose | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Pose) |
| Tile | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Tile) |
| Inpainting | [The AlimamaCreative Team](https://huggingface.co/alimama-creative) | [link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |
<Tip>
@@ -35,5 +44,10 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
- all
- __call__
## StableDiffusion3ControlNetInpaintingPipeline
[[autodoc]] pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetInpaintingPipeline
- all
- __call__
## StableDiffusion3PipelineOutput
[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
+84 -3
View File
@@ -37,7 +37,7 @@ Both checkpoints have slightly difference usage which we detail below.
```python
import torch
from diffusers import FluxPipeline
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
@@ -61,7 +61,7 @@ out.save("image.png")
```python
import torch
from diffusers import FluxPipeline
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
@@ -77,8 +77,89 @@ out = pipe(
out.save("image.png")
```
## Running FP16 inference
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
FP16 inference code:
```python
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) # can replace schnell with dev
# to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
pipe.enable_sequential_cpu_offload()
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
pipe.to(torch.float16) # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once
prompt = "A cat holding a sign that says hello world"
out = pipe(
prompt=prompt,
guidance_scale=0.,
height=768,
width=1360,
num_inference_steps=4,
max_sequence_length=256,
).images[0]
out.save("image.png")
```
## Single File Loading for the `FluxTransformer2DModel`
The `FluxTransformer2DModel` supports loading checkpoints in the original format shipped by Black Forest Labs. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
<Tip>
`FP8` inference can be brittle depending on the GPU type, CUDA version, and `torch` version that you are using. It is recommended that you use the `optimum-quanto` library in order to run FP8 inference on your machine.
</Tip>
The following example demonstrates how to run Flux with less than 16GB of VRAM.
First install `optimum-quanto`
```shell
pip install optimum-quanto
```
Then run the following example
```python
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline
from transformers import T5EncoderModel, CLIPTextModel
from optimum.quanto import freeze, qfloat8, quantize
bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
transformer = FluxTransformer2DModel.from_single_file("https://huggingface.co/Kijai/flux-fp8/blob/main/flux1-dev-fp8.safetensors", torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)
pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.enable_model_cpu_offload()
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt,
guidance_scale=3.5,
output_type="pil",
num_inference_steps=20,
generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-fp8-dev.png")
```
## FluxPipeline
[[autodoc]] FluxPipeline
- all
- __call__
- __call__
+2 -2
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/kolors_header_collage.png)
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](kwai-kolors@kuaishou.com). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](https://github.com/Kwai-Kolors/Kolors). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
The abstract from the technical report is:
@@ -74,7 +74,7 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
pipe = KolorsPipeline.from_pretrained(
"Kwai-Kolors/Kolors-diffusers", image_encoder=image_encoder, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.load_ip_adapter(
+12 -1
View File
@@ -20,7 +20,7 @@ The abstract from the paper is:
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.
PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.
- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
@@ -43,6 +43,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__
## KolorsPAGPipeline
[[autodoc]] KolorsPAGPipeline
- all
- __call__
## StableDiffusionPAGPipeline
[[autodoc]] StableDiffusionPAGPipeline
- all
@@ -74,6 +79,12 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- __call__
## StableDiffusion3PAGPipeline
[[autodoc]] StableDiffusion3PAGPipeline
- all
- __call__
## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
- all
+1 -1
View File
@@ -21,7 +21,7 @@ Stable Audio is trained on a corpus of around 48k audio recordings, where around
The abstract of the paper is the following:
*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.*
This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). The original codebase can be found at [Stability-AI/stable-audio-tool](https://github.com/Stability-AI/stable-audio-tool).
This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). The original codebase can be found at [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools).
## Tips
+78
View File
@@ -0,0 +1,78 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Community Projects
Welcome to Community Projects. This space is dedicated to showcasing the incredible work and innovative applications created by our vibrant community using the `diffusers` library.
This section aims to:
- Highlight diverse and inspiring projects built with `diffusers`
- Foster knowledge sharing within our community
- Provide real-world examples of how `diffusers` can be leveraged
Happy exploring, and thank you for being part of the Diffusers community!
<table>
<tr>
<th>Project Name</th>
<th>Description</th>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/carson-katri/dream-textures"> dream-textures </a></td>
<td>Stable Diffusion built-in to Blender</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/megvii-research/HiDiffusion"> HiDiffusion </a></td>
<td>Increases the resolution and speed of your diffusion model by only adding a single line of code</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/lllyasviel/IC-Light"> IC-Light </a></td>
<td>IC-Light is a project to manipulate the illumination of images</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/InstantID/InstantID"> InstantID </a></td>
<td>InstantID : Zero-shot Identity-Preserving Generation in Seconds</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/Sanster/IOPaint"> IOPaint </a></td>
<td>Image inpainting tool powered by SOTA AI Model. Remove any unwanted object, defect, people from your pictures or erase and replace(powered by stable diffusion) any thing on your pictures.</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/bmaltais/kohya_ss"> Kohya </a></td>
<td>Gradio GUI for Kohya's Stable Diffusion trainers</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/magic-research/magic-animate"> MagicAnimate </a></td>
<td>MagicAnimate: Temporally Consistent Human Image Animation using Diffusion Model</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/levihsu/OOTDiffusion"> OOTDiffusion </a></td>
<td>Outfitting Fusion based Latent Diffusion for Controllable Virtual Try-on</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/vladmandic/automatic"> SD.Next </a></td>
<td>SD.Next: Advanced Implementation of Stable Diffusion and other Diffusion-based generative image models</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/ashawkey/stable-dreamfusion"> stable-dreamfusion </a></td>
<td>Text-to-3D & Image-to-3D & Mesh Exportation with NeRF + Diffusion</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/HVision-NKU/StoryDiffusion"> StoryDiffusion </a></td>
<td>StoryDiffusion can create a magic story by generating consistent images and videos.</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/cumulo-autumn/StreamDiffusion"> StreamDiffusion </a></td>
<td>A Pipeline-Level Solution for Real-Time Interactive Generation</td>
</tr>
</table>
+2
View File
@@ -125,3 +125,5 @@ image
<figcaption class="mt-2 text-center text-sm text-gray-500">distilled Stable Diffusion + Tiny AutoEncoder</figcaption>
</div>
</div>
More tiny autoencoder models for other Stable Diffusion models, like Stable Diffusion 3, are available from [madebyollin](https://huggingface.co/madebyollin).
@@ -48,7 +48,7 @@ accelerate launch run_distributed.py --num_processes=2
<Tip>
To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
</Tip>
@@ -108,4 +108,4 @@ torchrun run_distributed.py --nproc_per_node=2
```
> [!TIP]
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
+3 -3
View File
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
[InstructPix2Pix](https://hf.co/papers/2211.09800) is a Stable Diffusion model trained to edit images from human-provided instructions. For example, your prompt can be "turn the clouds rainy" and the model will edit the input image accordingly. This model is conditioned on the text prompt (or editing instruction) and the input image.
This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use-case.
This guide will explore the [train_instruct_pix2pix.py](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py) training script to help you become familiar with it, and how you can adapt it for your own use case.
Before running the script, make sure you install the library from source:
@@ -117,7 +117,7 @@ optimizer = optimizer_cls(
)
```
Next, the edited images and and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important the same image transformations are applied to the original and edited images.
Next, the edited images and edit instructions are [preprocessed](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L624) and [tokenized](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L610C24-L610C24). It is important the same image transformations are applied to the original and edited images.
```py
def preprocess_train(examples):
@@ -249,4 +249,4 @@ The SDXL training script is discussed in more detail in the [SDXL training](sdxl
Congratulations on training your own InstructPix2Pix model! 🥳 To learn more about the model, it may be helpful to:
- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions.
- Read the [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd) blog post to learn more about some experiments we've done with InstructPix2Pix, dataset preparation, and results for different instructions.
@@ -34,7 +34,7 @@ pipe_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = DiffusionPipeline.from_pretrained(pipe_id, torch_dtype=torch.float16).to("cuda")
```
Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which let's you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
Next, load a [CiroN2022/toy-face](https://huggingface.co/CiroN2022/toy-face) adapter with the [`~diffusers.loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] method. With the 🤗 PEFT integration, you can assign a specific `adapter_name` to the checkpoint, which lets you easily switch between different LoRA checkpoints. Let's call this adapter `"toy"`.
```python
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+2 -2
View File
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
# Pipeline callbacks
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use-cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
The denoising loop of a pipeline can be modified with custom defined functions using the `callback_on_step_end` parameter. The callback function is executed at the end of each step, and modifies the pipeline attributes and variables for the next step. This is really useful for *dynamically* adjusting certain pipeline attributes or modifying tensor variables. This versatility allows for interesting use cases such as changing the prompt embeddings at each timestep, assigning different weights to the prompt embeddings, and editing the guidance scale. With callbacks, you can implement new features without modifying the underlying code!
> [!TIP]
> 🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
@@ -75,7 +75,7 @@ out.images[0].save("official_callback.png")
<figcaption class="mt-2 text-center text-sm text-gray-500">without SDXLCFGCutoffCallback</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_cfg_callback.png" alt="generated image of a a sports car at the road with cfg callback" />
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_cfg_callback.png" alt="generated image of a sports car at the road with cfg callback" />
<figcaption class="mt-2 text-center text-sm text-gray-500">with SDXLCFGCutoffCallback</figcaption>
</div>
</div>
@@ -289,9 +289,9 @@ scheduler = DPMSolverMultistepScheduler.from_pretrained(pipe_id, subfolder="sche
3. Load an image processor:
```python
from transformers import CLIPFeatureExtractor
from transformers import CLIPImageProcessor
feature_extractor = CLIPFeatureExtractor.from_pretrained(pipe_id, subfolder="feature_extractor")
feature_extractor = CLIPImageProcessor.from_pretrained(pipe_id, subfolder="feature_extractor")
```
<Tip warning={true}>
@@ -212,14 +212,14 @@ TCD-LoRA is very versatile, and it can be combined with other adapter types like
import torch
import numpy as np
from PIL import Image
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from transformers import DPTImageProcessor, DPTForDepthEstimation
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image, make_image_grid
from scheduling_tcd import TCDScheduler
device = "cuda"
depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to(device)
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
def get_depth_map(image):
image = feature_extractor(images=image, return_tensors="pt").pixel_values.to(device)
+15 -15
View File
@@ -14,9 +14,9 @@ specific language governing permissions and limitations under the License.
It can be fun and creative to use multiple [LoRAs]((https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora)) together to generate something entirely new and unique. This works by merging multiple LoRA weights together to produce images that are a blend of different styles. Diffusers provides a few methods to merge LoRAs depending on *how* you want to merge their weights, which can affect image quality.
This guide will show you how to merge LoRAs using the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
This guide will show you how to merge LoRAs using the [`~loaders.PeftAdapterMixin.set_adapters`] and [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) methods. To improve inference speed and reduce memory-usage of merged LoRAs, you'll also see how to use the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method to fuse the LoRA weights with the original weights of the underlying model.
For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style]() and [Norod78/sdxl-chalkboarddrawing-lora]() LoRAs with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
For this guide, load a Stable Diffusion XL (SDXL) checkpoint and the [KappaNeuro/studio-ghibli-style](https://huggingface.co/KappaNeuro/studio-ghibli-style) and [Norod78/sdxl-chalkboarddrawing-lora](https://huggingface.co/Norod78/sdxl-chalkboarddrawing-lora) LoRAs with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method. You'll need to assign each LoRA an `adapter_name` to combine them later.
```py
from diffusers import DiffusionPipeline
@@ -29,7 +29,7 @@ pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_
## set_adapters
The [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
The [`~loaders.PeftAdapterMixin.set_adapters`] method merges LoRA adapters by concatenating their weighted matrices. Use the adapter name to specify which LoRAs to merge, and the `adapter_weights` parameter to control the scaling for each LoRA. For example, if `adapter_weights=[0.5, 0.5]`, then the merged LoRA output is an average of both LoRAs. Try adjusting the adapter weights to see how it affects the generated image!
```py
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
@@ -47,19 +47,19 @@ image
## add_weighted_adapter
> [!WARNING]
> This is an experimental method that adds PEFTs [`~peft.LoraModel.add_weighted_adapter`] method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
> This is an experimental method that adds PEFTs [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method to Diffusers to enable more efficient merging methods. Check out this [issue](https://github.com/huggingface/diffusers/issues/6892) if you're interested in learning more about the motivation and design behind this integration.
The [`~peft.LoraModel.add_weighted_adapter`] method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
The [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method provides access to more efficient merging method such as [TIES and DARE](https://huggingface.co/docs/peft/developer_guides/model_merging). To use these merging methods, make sure you have the latest stable version of Diffusers and PEFT installed.
```bash
pip install -U diffusers peft
```
There are three steps to merge LoRAs with the [`~peft.LoraModel.add_weighted_adapter`] method:
There are three steps to merge LoRAs with the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method:
1. Create a [`~peft.PeftModel`] from the underlying model and LoRA checkpoint.
1. Create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the underlying model and LoRA checkpoint.
2. Load a base UNet model and the LoRA adapters.
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice.
3. Merge the adapters using the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method and the merging method of your choice.
Let's dive deeper into what these steps entail.
@@ -92,7 +92,7 @@ pipeline = DiffusionPipeline.from_pretrained(
pipeline.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
```
Now you'll create a [`~peft.PeftModel`] from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
Now you'll create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the loaded LoRA checkpoint by combining the SDXL UNet and the LoRA UNet from the pipeline.
```python
from peft import get_peft_model, LoraConfig
@@ -112,7 +112,7 @@ ikea_peft_model.load_state_dict(original_state_dict, strict=True)
> [!TIP]
> You can optionally push the ikea_peft_model to the Hub by calling `ikea_peft_model.push_to_hub("ikea_peft_model", token=TOKEN)`.
Repeat this process to create a [`~peft.PeftModel`] from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
Repeat this process to create a [PeftModel](https://huggingface.co/docs/peft/package_reference/peft_model#peft.PeftModel) from the [lordjia/by-feng-zikai](https://huggingface.co/lordjia/by-feng-zikai) LoRA.
```python
pipeline.delete_adapters("ikea")
@@ -148,7 +148,7 @@ model = PeftModel.from_pretrained(base_unet, "stevhliu/ikea_peft_model", use_saf
model.load_adapter("stevhliu/feng_peft_model", use_safetensors=True, subfolder="feng", adapter_name="feng")
```
3. Merge the adapters using the [`~peft.LoraModel.add_weighted_adapter`] method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
3. Merge the adapters using the [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) method and the merging method of your choice (learn more about other merging methods in this [blog post](https://huggingface.co/blog/peft_merging)). For this example, let's use the `"dare_linear"` method to merge the LoRAs.
> [!WARNING]
> Keep in mind the LoRAs need to have the same rank to be merged!
@@ -182,9 +182,9 @@ image
## fuse_lora
Both the [`~loaders.UNet2DConditionLoadersMixin.set_adapters`] and [`~peft.LoraModel.add_weighted_adapter`] methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
Both the [`~loaders.PeftAdapterMixin.set_adapters`] and [add_weighted_adapter](https://huggingface.co/docs/peft/package_reference/lora#peft.LoraModel.add_weighted_adapter) methods require loading the base model and the LoRA adapters separately which incurs some overhead. The [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method allows you to fuse the LoRA weights directly with the original weights of the underlying model. This way, you're only loading the model once which can increase inference and lower memory-usage.
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
For example, if you have a base model and adapters loaded and set as active with the following adapter weights:
@@ -199,7 +199,7 @@ pipeline.load_lora_weights("lordjia/by-feng-zikai", weight_name="fengzikai_v1.0_
pipeline.set_adapters(["ikea", "feng"], adapter_weights=[0.7, 0.8])
```
Fuse these LoRAs into the UNet with the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] method because it wont work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
Fuse these LoRAs into the UNet with the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method. The `lora_scale` parameter controls how much to scale the output by with the LoRA weights. It is important to make the `lora_scale` adjustments in the [`~loaders.lora_base.LoraBaseMixin.fuse_lora`] method because it wont work if you try to pass `scale` to the `cross_attention_kwargs` in the pipeline.
```py
pipeline.fuse_lora(adapter_names=["ikea", "feng"], lora_scale=1.0)
@@ -226,7 +226,7 @@ image = pipeline("A bowl of ramen shaped like a cute kawaii bear, by Feng Zikai"
image
```
You can call [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
You can call [`~~loaders.lora_base.LoraBaseMixin.unfuse_lora`] to restore the original model's weights (for example, if you want to use a different `lora_scale` value). However, this only works if you've only fused one LoRA adapter to the original model. If you've fused multiple LoRAs, you'll need to reload the model.
```py
pipeline.unfuse_lora()
+1 -1
View File
@@ -307,7 +307,7 @@ print(pipeline)
위의 코드 출력 결과를 확인해보면, `pipeline`은 [`StableDiffusionPipeline`]의 인스턴스이며, 다음과 같이 총 7개의 컴포넌트로 구성된다는 것을 알 수 있습니다.
- `"feature_extractor"`: [`~transformers.CLIPFeatureExtractor`]의 인스턴스
- `"feature_extractor"`: [`~transformers.CLIPImageProcessor`]의 인스턴스
- `"safety_checker"`: 유해한 컨텐츠를 스크리닝하기 위한 [컴포넌트](https://github.com/huggingface/diffusers/blob/e55687e1e15407f60f32242027b7bb8170e58266/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L32)
- `"scheduler"`: [`PNDMScheduler`]의 인스턴스
- `"text_encoder"`: [`~transformers.CLIPTextModel`]의 인스턴스
@@ -24,7 +24,7 @@ import PIL
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
def image_grid(imgs, rows, cols):
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+52 -7
View File
@@ -71,6 +71,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffsuion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -1435,9 +1436,9 @@ import requests
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel
from transformers import CLIPImageProcessor, CLIPModel
feature_extractor = CLIPFeatureExtractor.from_pretrained(
feature_extractor = CLIPImageProcessor.from_pretrained(
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
)
clip_model = CLIPModel.from_pretrained(
@@ -1646,7 +1647,6 @@ from diffusers import DiffusionPipeline
scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1",
subfolder="scheduler")
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1",
custom_pipeline="stable_diffusion_tensorrt_img2img",
variant='fp16',
@@ -1661,7 +1661,6 @@ pipe = pipe.to("cuda")
url = "https://pajoca.com/wp-content/uploads/2022/09/tekito-yamakawa-1.png"
response = requests.get(url)
input_image = Image.open(BytesIO(response.content)).convert("RGB")
prompt = "photorealistic new zealand hills"
image = pipe(prompt, image=input_image, strength=0.75,).images[0]
image.save('tensorrt_img2img_new_zealand_hills.png')
@@ -2122,7 +2121,7 @@ import torch
import open_clip
from open_clip import SimpleTokenizer
from diffusers import DiffusionPipeline
from transformers import CLIPFeatureExtractor, CLIPModel
from transformers import CLIPImageProcessor, CLIPModel
def download_image(url):
@@ -2130,7 +2129,7 @@ def download_image(url):
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
# Loading additional models
feature_extractor = CLIPFeatureExtractor.from_pretrained(
feature_extractor = CLIPImageProcessor.from_pretrained(
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
)
clip_model = CLIPModel.from_pretrained(
@@ -4209,6 +4208,52 @@ print("Latency of AnimateDiffPipelineIpex--fp32", latency, "s for total", step,
latency = elapsed_time(pipe4, num_inference_steps=step)
print("Latency of AnimateDiffPipeline--fp32",latency, "s for total", step, "steps")
```
### HunyuanDiT with Differential Diffusion
#### Usage
```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import load_image
from PIL import Image
from torchvision import transforms
from pipeline_hunyuandit_differential_img2img import (
HunyuanDiTDifferentialImg2ImgPipeline,
)
pipe = HunyuanDiTDifferentialImg2ImgPipeline.from_pretrained(
"Tencent-Hunyuan/HunyuanDiT-Diffusers", torch_dtype=torch.float16
).to("cuda")
source_image = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png"
)
map = load_image(
"https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask_2.png"
)
prompt = "a green pear"
negative_prompt = "blurry"
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
image=source_image,
num_inference_steps=28,
guidance_scale=4.5,
strength=1.0,
map=map,
).images[0]
```
| ![Gradient](https://github.com/user-attachments/assets/e38ce4d5-1ae6-4df0-ab43-adc1b45716b5) | ![Input](https://github.com/user-attachments/assets/9c95679c-e9d7-4f5a-90d6-560203acd6b3) | ![Output](https://github.com/user-attachments/assets/5313ff64-a0c4-418b-8b55-a38f1a5e7532) |
| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
| Gradient | Input | Output |
A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
# Perturbed-Attention Guidance
@@ -4285,4 +4330,4 @@ grid_image.save(grid_dir + "sample.png")
`pag_scale` : guidance scale of PAG (ex: 5.0)
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
@@ -7,7 +7,7 @@ import PIL.Image
import torch
from torch.nn import functional as F
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
@@ -86,7 +86,7 @@ class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline, StableDiffusionMi
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
coca_model=None,
coca_tokenizer=None,
coca_transform=None,
@@ -7,7 +7,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
@@ -32,9 +32,9 @@ EXAMPLE_DOC_STRING = """
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel
from transformers import CLIPImageProcessor, CLIPModel
feature_extractor = CLIPFeatureExtractor.from_pretrained(
feature_extractor = CLIPImageProcessor.from_pretrained(
"laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
)
clip_model = CLIPModel.from_pretrained(
@@ -139,7 +139,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler],
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
):
super().__init__()
self.register_modules(
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
class MarigoldDepthOutput(BaseOutput):
+2 -2
View File
@@ -9,7 +9,7 @@ import torch
from numpy import exp, pi, sqrt
from torchvision.transforms.functional import resize
from tqdm.auto import tqdm
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -275,7 +275,7 @@ class StableDiffusionCanvasPipeline(DiffusionPipeline, StableDiffusionMixin):
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
):
super().__init__()
self.register_modules(
+2 -2
View File
@@ -15,7 +15,7 @@ from diffusers.utils import logging
try:
from ligo.segments import segment
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
except ImportError:
raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline")
@@ -144,7 +144,7 @@ class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixi
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler],
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
):
super().__init__()
self.register_modules(
File diff suppressed because it is too large Load Diff
@@ -189,7 +189,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -332,7 +332,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
Whether the `unet` requires a aesthetic_score condition to be passed during inference. Also see the config
+3 -3
View File
@@ -9,7 +9,7 @@ import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
# from ...configuration_utils import FrozenDict
# from ...models import AutoencoderKL, UNet2DConditionModel
@@ -87,7 +87,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
cc_projection ([`CCProjection`]):
Projection layer to project the concated CLIP features and pose embeddings to the original CLIP feature size.
@@ -102,7 +102,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
cc_projection: CCProjection,
requires_safety_checker: bool = True,
):
@@ -3,7 +3,7 @@ from typing import Dict, Optional
import torch
import torchvision.transforms.functional as FF
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers import StableDiffusionPipeline
from diffusers.models import AutoencoderKL, UNet2DConditionModel
@@ -69,7 +69,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__(
+3 -3
View File
@@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import intel_extension_for_pytorch as ipex
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from diffusers.configuration_utils import FrozenDict
from diffusers.loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
@@ -86,7 +86,7 @@ class StableDiffusionIPEXPipeline(
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -100,7 +100,7 @@ class StableDiffusionIPEXPipeline(
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
network_from_onnx_path,
save_engine,
)
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict, deprecate
@@ -679,7 +679,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -693,7 +693,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
network_from_onnx_path,
save_engine,
)
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict, deprecate
@@ -683,7 +683,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -697,7 +697,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae", "vae_encoder"],
@@ -42,7 +42,7 @@ from polygraphy.backend.trt import (
network_from_onnx_path,
save_engine,
)
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import FrozenDict, deprecate
@@ -595,7 +595,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -609,7 +609,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
unet: UNet2DConditionModel,
scheduler: DDIMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
stages=["clip", "unet", "vae"],
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = logging.getLogger(__name__)
+1 -1
View File
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+195
View File
@@ -0,0 +1,195 @@
# DreamBooth training example for FLUX.1 [dev]
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.
> [!NOTE]
> **Memory consumption**
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
```bash
huggingface-cli login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd in the `examples/dreambooth` folder and run
```bash
pip install -r requirements_flux.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Dog toy example
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
Let's first download it locally:
```python
from huggingface_hub import snapshot_download
local_dir = "./dog"
snapshot_download(
"diffusers/dog-example",
local_dir=local_dir, repo_type="dataset",
ignore_patterns=".gitattributes",
)
```
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
Now, we can launch training using:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux"
accelerate launch train_dreambooth_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
To perform DreamBooth with LoRA, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux-lora"
accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
### Text Encoder Training
Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
> [!NOTE]
> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
To perform DreamBooth LoRA with text-encoder training, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export OUTPUT_DIR="trained-flux-dev-dreambooth-lora"
accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--mixed_precision="bf16" \
--train_text_encoder\
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--seed="0" \
--push_to_hub
```
## Other notes
Thanks to `bghira` for their help with reviewing & insight sharing ♥️
@@ -0,0 +1,8 @@
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft>=0.11.1
sentencepiece
+203
View File
@@ -0,0 +1,203 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import shutil
import sys
import tempfile
from diffusers import DiffusionPipeline, FluxTransformer2DModel
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothFlux(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/dreambooth/train_dreambooth_flux.py"
def test_dreambooth(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_dreambooth_checkpointing(self):
with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 4, checkpointing_steps == 2
# Should create checkpoints at steps 2, 4
initial_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 4
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--seed=0
""".split()
run_command(self._launch_args + initial_run_args)
# check can run the original fully trained output pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir)
pipe(self.instance_prompt, num_inference_steps=1)
# check checkpoint directories exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
# check can run an intermediate checkpoint
transformer = FluxTransformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
pipe(self.instance_prompt, num_inference_steps=1)
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
# Run training script for 7 total steps resuming from checkpoint 4
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 6
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--seed=0
""".split()
run_command(self._launch_args + resume_run_args)
# check can run new fully trained pipeline
pipe = DiffusionPipeline.from_pretrained(tmpdir)
pipe(self.instance_prompt, num_inference_steps=1)
# check old checkpoints do not exist
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
# check new checkpoints exist
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-2", "checkpoint-4"},
)
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
@@ -0,0 +1,165 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import sys
import tempfile
import safetensors
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "photo"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
def test_dreambooth_lora_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_text_encoder_flux(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--train_text_encoder
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
starts_with_expected_prefix = all(
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_expected_prefix)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+1 -1
View File
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
File diff suppressed because it is too large Load Diff
+1 -1
View File
@@ -70,7 +70,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
File diff suppressed because it is too large Load Diff
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -1271,7 +1271,7 @@ def main(args):
lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
transformer_state_dict = {
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
f'{k.replace("transformer.", "")}': v for k, v in lora_state_dict.items() if k.startswith("transformer.")
}
transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
@@ -1454,7 +1454,7 @@ def main(args):
)
# Clear the memory here
if not args.train_text_encoder and train_dataset.custom_instance_prompts:
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+1 -1
View File
@@ -64,7 +64,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -43,7 +43,7 @@ from PIL import Image
from torch.utils.data import default_collate
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, DPTFeatureExtractor, DPTForDepthEstimation, PretrainedConfig
from transformers import AutoTokenizer, DPTForDepthEstimation, DPTImageProcessor, PretrainedConfig
from webdataset.tariterators import (
base_plus_ext,
tar_file_expander,
@@ -205,7 +205,7 @@ class Text2ImageDataset:
pin_memory: bool = False,
persistent_workers: bool = False,
control_type: str = "canny",
feature_extractor: Optional[DPTFeatureExtractor] = None,
feature_extractor: Optional[DPTImageProcessor] = None,
):
if not isinstance(train_shards_path_or_url, str):
train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url]
@@ -1011,7 +1011,7 @@ def main(args):
controlnet = pre_controlnet
if args.control_type == "depth":
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas")
depth_model.requires_grad_(False)
else:
+1 -1
View File
@@ -45,7 +45,7 @@
" UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n",
")\n",
"from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
"from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
import torch
from PIL import Image
from retriever import Retriever, normalize_images, preprocess_images
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPModel, CLIPTokenizer
from diffusers import (
AutoencoderKL,
@@ -47,7 +47,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
feature_extractor ([`CLIPFeatureExtractor`]):
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -65,7 +65,7 @@ class RDMPipeline(DiffusionPipeline, StableDiffusionMixin):
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
feature_extractor: CLIPFeatureExtractor,
feature_extractor: CLIPImageProcessor,
retriever: Optional[Retriever] = None,
):
super().__init__()
+7 -9
View File
@@ -6,7 +6,7 @@ import numpy as np
import torch
from datasets import Dataset, load_dataset
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, PretrainedConfig
from transformers import CLIPImageProcessor, CLIPModel, PretrainedConfig
from diffusers import logging
@@ -20,7 +20,7 @@ def normalize_images(images: List[Image.Image]):
return images
def preprocess_images(images: List[np.array], feature_extractor: CLIPFeatureExtractor) -> torch.Tensor:
def preprocess_images(images: List[np.array], feature_extractor: CLIPImageProcessor) -> torch.Tensor:
"""
Preprocesses a list of images into a batch of tensors.
@@ -95,14 +95,12 @@ class Index:
def build_index(
self,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
feature_extractor: CLIPImageProcessor = None,
torch_dtype=torch.float32,
):
if not self.index_initialized:
model = model or CLIPModel.from_pretrained(self.config.clip_name_or_path).to(dtype=torch_dtype)
feature_extractor = feature_extractor or CLIPFeatureExtractor.from_pretrained(
self.config.clip_name_or_path
)
feature_extractor = feature_extractor or CLIPImageProcessor.from_pretrained(self.config.clip_name_or_path)
self.dataset = get_dataset_with_emb_from_clip_model(
self.dataset,
model,
@@ -136,7 +134,7 @@ class Retriever:
index: Index = None,
dataset: Dataset = None,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
feature_extractor: CLIPImageProcessor = None,
):
self.config = config
self.index = index or self._build_index(config, dataset, model=model, feature_extractor=feature_extractor)
@@ -148,7 +146,7 @@ class Retriever:
index: Index = None,
dataset: Dataset = None,
model=None,
feature_extractor: CLIPFeatureExtractor = None,
feature_extractor: CLIPImageProcessor = None,
**kwargs,
):
config = kwargs.pop("config", None) or IndexConfig.from_pretrained(retriever_name_or_path, **kwargs)
@@ -156,7 +154,7 @@ class Retriever:
@staticmethod
def _build_index(
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPFeatureExtractor = None
config: IndexConfig, dataset: Dataset = None, model=None, feature_extractor: CLIPImageProcessor = None
):
dataset = dataset or load_dataset(config.dataset_name)
dataset = dataset[config.dataset_set]
@@ -2,8 +2,8 @@ diffusers==0.20.1
accelerate==0.23.0
transformers==4.38.0
peft==0.5.0
torch==2.0.1
torch==2.2.0
torchvision>=0.16
ftfy==6.1.1
tensorboard==2.14.0
Jinja2==3.1.3
Jinja2==3.1.4
@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
NUM_DEVICES = jax.device_count()
# 1. Let's start by downloading the model and loading it into our pipeline class
# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
# Adhering to JAX's functional approach, the model's parameters are returned separately and
# will have to be passed to the pipeline during inference
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
@@ -69,7 +69,7 @@ def replicate_all(prompt_ids, neg_prompt_ids, seed):
# to the function and tell JAX which are static arguments, that is, arguments that
# are known at compile time and won't change. In our case, it is num_inference_steps,
# height, width and return_latents.
# Once the function is compiled, these parameters are ommited from future calls and
# Once the function is compiled, these parameters are omitted from future calls and
# cannot be changed without modifying the code and recompiling.
def aot_compile(
prompt=default_prompt,
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
+19 -8
View File
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -826,17 +826,22 @@ def main():
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
)
# Prepare everything with our `accelerator`.
@@ -866,8 +871,14 @@ def main():
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -478,7 +478,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--debug_loss",
action="store_true",
help="debug loss for each image, if filenames are awailable in the dataset",
help="debug loss for each image, if filenames are available in the dataset",
)
if input_args is not None:
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
+3
View File
@@ -109,6 +109,9 @@ import torch
model_id = "path-to-your-trained-model"
pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
repo_id_embeds = "path-to-your-learned-embeds"
pipe.load_textual_inversion(repo_id_embeds)
prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+22 -1
View File
@@ -23,4 +23,25 @@ accelerate launch textual_inversion_sdxl.py \
--output_dir="./textual_inversion_cat_sdxl"
```
For now, only training of the first text encoder is supported.
Training of both text encoders is supported.
### Inference Example
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionXLPipeline`.
Make sure to include the `placeholder_token` in your prompt.
```python
from diffusers import StableDiffusionXLPipeline
import torch
model_id = "./textual_inversion_cat_sdxl"
pipe = StableDiffusionXLPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack-prompt_2.png")
```
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = logging.getLogger(__name__)
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
@@ -135,7 +135,7 @@ def log_validation(
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder_1),
text_encoder_2=text_encoder_2,
text_encoder_2=accelerator.unwrap_model(text_encoder_2),
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
unet=unet,
@@ -678,36 +678,54 @@ def main():
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
if num_added_tokens != args.num_vectors:
raise ValueError(
f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
if len(token_ids) > 1 or len(token_ids_2) > 1:
raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
initializer_token_id_2 = token_ids_2[0]
placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder_1.resize_token_embeddings(len(tokenizer_1))
text_encoder_2.resize_token_embeddings(len(tokenizer_2))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder_1.get_input_embeddings().weight.data
token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
with torch.no_grad():
for token_id in placeholder_token_ids:
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
for token_id in placeholder_token_ids_2:
token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
# Freeze vae and unet
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_2.requires_grad_(False)
# Freeze all parameters except for the token embeddings in text encoder
text_encoder_1.text_model.encoder.requires_grad_(False)
text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
text_encoder_2.text_model.encoder.requires_grad_(False)
text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
if args.gradient_checkpointing:
text_encoder_1.gradient_checkpointing_enable()
text_encoder_2.gradient_checkpointing_enable()
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -746,7 +764,11 @@ def main():
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
text_encoder_1.get_input_embeddings().parameters(), # only optimize the embeddings
# only optimize the embeddings
[
text_encoder_1.text_model.embeddings.token_embedding.weight,
text_encoder_2.text_model.embeddings.token_embedding.weight,
],
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -786,9 +808,10 @@ def main():
)
text_encoder_1.train()
text_encoder_2.train()
# Prepare everything with our `accelerator`.
text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_1, optimizer, train_dataloader, lr_scheduler
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
)
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -866,11 +889,13 @@ def main():
# keep original embeddings as reference
orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()
for epoch in range(first_epoch, args.num_train_epochs):
text_encoder_1.train()
text_encoder_2.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder_1):
with accelerator.accumulate([text_encoder_1, text_encoder_2]):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * vae.config.scaling_factor
@@ -892,9 +917,7 @@ def main():
.hidden_states[-2]
.to(dtype=weight_dtype)
)
encoder_output_2 = text_encoder_2(
batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
)
encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
original_size = [
(batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
@@ -938,11 +961,16 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
index_no_updates_2
] = orig_embeds_params_2[index_no_updates_2]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -960,6 +988,16 @@ def main():
save_path,
safe_serialization=True,
)
weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
save_path = os.path.join(args.output_dir, weight_name)
save_progress(
text_encoder_2,
placeholder_token_ids_2,
accelerator,
args,
save_path,
safe_serialization=True,
)
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
@@ -1034,7 +1072,7 @@ def main():
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder_1),
text_encoder_2=text_encoder_2,
text_encoder_2=accelerator.unwrap_model(text_encoder_2),
vae=vae,
unet=unet,
tokenizer=tokenizer_1,
@@ -1052,6 +1090,16 @@ def main():
save_path,
safe_serialization=True,
)
weight_name = "learned_embeds_2.safetensors"
save_path = os.path.join(args.output_dir, weight_name)
save_progress(
text_encoder_2,
placeholder_token_ids_2,
accelerator,
args,
save_path,
safe_serialization=True,
)
if args.push_to_hub:
save_model_card(
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -50,7 +50,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.30.0.dev0")
check_min_version("0.31.0.dev0")
logger = get_logger(__name__, log_level="INFO")
+1 -1
View File
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
+29 -9
View File
@@ -1,4 +1,4 @@
__version__ = "0.30.0.dev0"
__version__ = "0.31.0.dev0"
from typing import TYPE_CHECKING
@@ -12,6 +12,7 @@ from .utils import (
is_note_seq_available,
is_onnx_available,
is_scipy_available,
is_sentencepiece_available,
is_torch_available,
is_torchsde_available,
is_transformers_available,
@@ -87,6 +88,7 @@ else:
"ControlNetModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"FluxControlNetModel",
"FluxTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -250,11 +252,10 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"ChatGLMModel",
"ChatGLMTokenizer",
"CLIPImageProjection",
"CogVideoXPipeline",
"CycleDiffusionPipeline",
"FluxControlNetPipeline",
"FluxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -286,8 +287,6 @@ else:
"KandinskyV22Pipeline",
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
"KolorsImg2ImgPipeline",
"KolorsPipeline",
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
"LattePipeline",
@@ -311,9 +310,11 @@ else:
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
"StableDiffusion3ControlNetInpaintingPipeline",
"StableDiffusion3ControlNetPipeline",
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline",
@@ -391,6 +392,19 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
try:
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
_import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
]
else:
_import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"])
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
@@ -538,6 +552,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
FluxControlNetModel,
FluxTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -679,11 +694,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
ChatGLMModel,
ChatGLMTokenizer,
CLIPImageProjection,
CogVideoXPipeline,
CycleDiffusionPipeline,
FluxControlNetPipeline,
FluxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
@@ -715,8 +729,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KandinskyV22Pipeline,
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
KolorsImg2ImgPipeline,
KolorsPipeline,
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
LattePipeline,
@@ -743,6 +755,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
StableDiffusionAdapterPipeline,
StableDiffusionAttendAndExcitePipeline,
@@ -814,6 +827,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403
else:
from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
+5 -1
View File
@@ -222,7 +222,11 @@ class IPAdapterMixin:
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
clip_image_size = self.image_encoder.config.image_size
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
@@ -24,6 +24,7 @@ from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_sd3_transformer_checkpoint_to_diffusers,
@@ -74,6 +75,13 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"MotionAdapter": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"SparseControlNetModel": {
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
},
"FluxTransformer2DModel": {
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
+212 -2
View File
@@ -74,9 +74,12 @@ CHECKPOINT_KEY_NAMES = {
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
"animatediff_rgb": "controlnet_cond_embedding.weight",
"flux": "double_blocks.0.img_attn.norm.key_norm.scale",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -110,6 +113,10 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
}
# Use to configure model sample size when original config is provided
@@ -491,7 +498,13 @@ def infer_diffusers_model_type(checkpoint):
model_type = "sd3"
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint:
model_type = "animatediff_scribble"
elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint:
model_type = "animatediff_rgb"
elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
model_type = "animatediff_v2"
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
@@ -503,6 +516,11 @@ def infer_diffusers_model_type(checkpoint):
else:
model_type = "animatediff_v3"
elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
if "guidance_in.in_layer.bias" in checkpoint:
model_type = "flux-dev"
else:
model_type = "flux-schnell"
else:
model_type = "v1"
@@ -1859,3 +1877,195 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
] = v
return converted_state_dict
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
mlp_ratio = 4.0
inner_dim = 3072
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
## time_text_embed.timestep_embedder <- time_in
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"time_in.in_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("time_in.in_layer.bias")
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"time_in.out_layer.weight"
)
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("time_in.out_layer.bias")
## time_text_embed.text_embedder <- vector_in
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("vector_in.in_layer.weight")
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("vector_in.in_layer.bias")
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop(
"vector_in.out_layer.weight"
)
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("vector_in.out_layer.bias")
# guidance
has_guidance = any("guidance" in k for k in checkpoint)
if has_guidance:
converted_state_dict["time_text_embed.guidance_embedder.linear_1.weight"] = checkpoint.pop(
"guidance_in.in_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_1.bias"] = checkpoint.pop(
"guidance_in.in_layer.bias"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.weight"] = checkpoint.pop(
"guidance_in.out_layer.weight"
)
converted_state_dict["time_text_embed.guidance_embedder.linear_2.bias"] = checkpoint.pop(
"guidance_in.out_layer.bias"
)
# context_embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
# x_embedder
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
# norms.
## norm1
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mod.lin.weight"
)
converted_state_dict[f"{block_prefix}norm1.linear.bias"] = checkpoint.pop(
f"double_blocks.{i}.img_mod.lin.bias"
)
## norm1_context
converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mod.lin.weight"
)
converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mod.lin.bias"
)
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
context_q, context_k, context_v = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# qk_norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)
# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.bias"
)
# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.bias"
)
# single transfomer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.weight"
)
converted_state_dict[f"{block_prefix}norm.linear.bias"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.bias"
)
# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
q_bias, k_bias, v_bias, mlp_bias = torch.split(
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
# qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.key_norm.scale"
)
# output projections.
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.bias")
)
return converted_state_dict
+2
View File
@@ -35,6 +35,7 @@ if is_torch_available():
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_flux"] = ["FluxControlNetModel"]
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
@@ -87,6 +88,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
)
from .controlnet import ControlNetModel
from .controlnet_flux import FluxControlNetModel
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_sparsectrl import SparseControlNetModel
+327 -3
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
@@ -272,6 +272,17 @@ class BasicTransformerBlock(nn.Module):
attention_out_bias: bool = True,
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
# We keep these boolean flags for backward-compatibility.
@@ -376,7 +387,7 @@ class BasicTransformerBlock(nn.Module):
"layer_norm",
)
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm", "ada_norm_continuous"]:
elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
elif norm_type == "layer_norm_i2vgen":
self.norm3 = None
@@ -438,7 +449,7 @@ class BasicTransformerBlock(nn.Module):
norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
elif self.norm_type == "ada_norm_single":
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
self.scale_shift_table[None].to(timestep.dtype) + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
@@ -782,6 +793,319 @@ class SkipFFTransformerBlock(nn.Module):
return hidden_states
@maybe_allow_in_graph
class FreeNoiseTransformerBlock(nn.Module):
r"""
A FreeNoise Transformer block.
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
cross_attention_dim (`int`, *optional*):
The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward.
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, defaults to `False`):
Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, defaults to `False`):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, defaults to `False`):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, defaults to `False`):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
ff_inner_dim (`int`, *optional*):
Hidden dimension of feed-forward MLP.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in feed-forward MLP.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in attention output project layer.
context_length (`int`, defaults to `16`):
The maximum number of frames that the FreeNoise block processes at once.
context_stride (`int`, defaults to `4`):
The number of frames to be skipped before starting to process a new batch of `context_length` frames.
weighting_scheme (`str`, defaults to `"pyramid"`):
The weighting scheme to use for weighting averaging of processed latent frames. As described in the
Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
used.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
norm_eps: float = 1e-5,
final_dropout: bool = False,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
context_length: int = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
):
super().__init__()
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.dropout = dropout
self.cross_attention_dim = cross_attention_dim
self.activation_fn = activation_fn
self.attention_bias = attention_bias
self.double_self_attention = double_self_attention
self.norm_elementwise_affine = norm_elementwise_affine
self.positional_embeddings = positional_embeddings
self.num_positional_embeddings = num_positional_embeddings
self.only_cross_attention = only_cross_attention
self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
# We keep these boolean flags for backward-compatibility.
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
self.norm_type = norm_type
self.num_embeds_ada_norm = num_embeds_ada_norm
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
out_bias=attention_out_bias,
) # is self-attn if encoder_hidden_states is none
# 3. Feed-forward
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
frame_indices = []
for i in range(0, num_frames - self.context_length + 1, self.context_stride):
window_start = i
window_end = min(num_frames, i + self.context_length)
frame_indices.append((window_start, window_end))
return frame_indices
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
if weighting_scheme == "pyramid":
if num_frames % 2 == 0:
# num_frames = 4 => [1, 2, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + weights[::-1]
else:
# num_frames = 5 => [1, 2, 3, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + [num_frames // 2 + 1] + weights[::-1]
else:
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
return weights
def set_free_noise_properties(
self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
) -> None:
self.context_length = context_length
self.context_stride = context_stride
self.weighting_scheme = weighting_scheme
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
*args,
**kwargs,
) -> torch.Tensor:
if cross_attention_kwargs is not None:
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
# hidden_states: [B x H x W, F, C]
device = hidden_states.device
dtype = hidden_states.dtype
num_frames = hidden_states.size(1)
frame_indices = self._get_frame_indices(num_frames)
frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
# Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
# For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
# [(0, 16), (4, 20), (8, 24), (10, 26)]
if not is_last_frame_batch_complete:
if num_frames < self.context_length:
raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
last_frame_batch_length = num_frames - frame_indices[-1][1]
frame_indices.append((num_frames - self.context_length, num_frames))
num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
accumulated_values = torch.zeros_like(hidden_states)
for i, (frame_start, frame_end) in enumerate(frame_indices):
# The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
# cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
# essentially a non-multiple of `context_length`.
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
weights *= frame_weights
hidden_states_chunk = hidden_states[:, frame_start:frame_end]
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states_chunk)
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states_chunk = attn_output + hidden_states_chunk
if hidden_states_chunk.ndim == 4:
hidden_states_chunk = hidden_states_chunk.squeeze(1)
# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states_chunk)
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states_chunk = attn_output + hidden_states_chunk
if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
accumulated_values[:, -last_frame_batch_length:] += (
hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
)
num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
else:
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights
hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
).to(dtype)
# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self._chunk_size is not None:
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
else:
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class FeedForward(nn.Module):
r"""
A feed-forward layer.
+426 -5
View File
@@ -227,6 +227,7 @@ class Attention(nn.Module):
self.to_k = None
self.to_v = None
self.added_proj_bias = added_proj_bias
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
@@ -698,12 +699,15 @@ class Attention(nn.Module):
in_features = concatenated_weights.shape[1]
out_features = concatenated_weights.shape[0]
self.to_added_qkv = nn.Linear(in_features, out_features, bias=True, device=device, dtype=dtype)
self.to_added_qkv.weight.copy_(concatenated_weights)
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
self.to_added_qkv = nn.Linear(
in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.to_added_qkv.weight.copy_(concatenated_weights)
if self.added_proj_bias:
concatenated_bias = torch.cat(
[self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
)
self.to_added_qkv.bias.copy_(concatenated_bias)
self.fused_projections = fuse
@@ -1102,6 +1106,326 @@ class JointAttnProcessor2_0:
return hidden_states, encoder_hidden_states
class PAGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
# store the length of image patch sequences to create a mask that prevents interaction between patches
# similar to making the self-attention map an identity matrix
identity_block_size = hidden_states.shape[1]
# chunk
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
encoder_hidden_states_org, encoder_hidden_states_ptb = encoder_hidden_states.chunk(2)
################## original path ##################
batch_size = encoder_hidden_states_org.shape[0]
# `sample` projections.
query_org = attn.to_q(hidden_states_org)
key_org = attn.to_k(hidden_states_org)
value_org = attn.to_v(hidden_states_org)
# `context` projections.
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
# attention
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
inner_dim = key_org.shape[-1]
head_dim = inner_dim // attn.heads
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states_org = F.scaled_dot_product_attention(
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query_org.dtype)
# Split the attention outputs.
hidden_states_org, encoder_hidden_states_org = (
hidden_states_org[:, : residual.shape[1]],
hidden_states_org[:, residual.shape[1] :],
)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if not attn.context_pre_only:
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
################## perturbed path ##################
batch_size = encoder_hidden_states_ptb.shape[0]
# `sample` projections.
query_ptb = attn.to_q(hidden_states_ptb)
key_ptb = attn.to_k(hidden_states_ptb)
value_ptb = attn.to_v(hidden_states_ptb)
# `context` projections.
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
# attention
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
inner_dim = key_ptb.shape[-1]
head_dim = inner_dim // attn.heads
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# create a full mask with all entries set to 0
seq_len = query_ptb.size(2)
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
# set the attention value between image patches to -inf
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
# set the diagonal of the attention value between image patches to 0
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
# expand the mask to match the attention weights shape
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
hidden_states_ptb = F.scaled_dot_product_attention(
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
)
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
# split the attention outputs.
hidden_states_ptb, encoder_hidden_states_ptb = (
hidden_states_ptb[:, : residual.shape[1]],
hidden_states_ptb[:, residual.shape[1] :],
)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if not attn.context_pre_only:
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
################ concat ###############
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
return hidden_states, encoder_hidden_states
class PAGCFGJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"PAGCFGJointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
residual = hidden_states
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
context_input_ndim = encoder_hidden_states.ndim
if context_input_ndim == 4:
batch_size, channel, height, width = encoder_hidden_states.shape
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
identity_block_size = hidden_states.shape[
1
] # patch embeddings width * height (correspond to self-attention map width or height)
# chunk
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
(
encoder_hidden_states_uncond,
encoder_hidden_states_org,
encoder_hidden_states_ptb,
) = encoder_hidden_states.chunk(3)
encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org])
################## original path ##################
batch_size = encoder_hidden_states_org.shape[0]
# `sample` projections.
query_org = attn.to_q(hidden_states_org)
key_org = attn.to_k(hidden_states_org)
value_org = attn.to_v(hidden_states_org)
# `context` projections.
encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
# attention
query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
inner_dim = key_org.shape[-1]
head_dim = inner_dim // attn.heads
query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
hidden_states_org = F.scaled_dot_product_attention(
query_org, key_org, value_org, dropout_p=0.0, is_causal=False
)
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_org = hidden_states_org.to(query_org.dtype)
# Split the attention outputs.
hidden_states_org, encoder_hidden_states_org = (
hidden_states_org[:, : residual.shape[1]],
hidden_states_org[:, residual.shape[1] :],
)
# linear proj
hidden_states_org = attn.to_out[0](hidden_states_org)
# dropout
hidden_states_org = attn.to_out[1](hidden_states_org)
if not attn.context_pre_only:
encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
if input_ndim == 4:
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
################## perturbed path ##################
batch_size = encoder_hidden_states_ptb.shape[0]
# `sample` projections.
query_ptb = attn.to_q(hidden_states_ptb)
key_ptb = attn.to_k(hidden_states_ptb)
value_ptb = attn.to_v(hidden_states_ptb)
# `context` projections.
encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
# attention
query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
inner_dim = key_ptb.shape[-1]
head_dim = inner_dim // attn.heads
query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# create a full mask with all entries set to 0
seq_len = query_ptb.size(2)
full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
# set the attention value between image patches to -inf
full_mask[:identity_block_size, :identity_block_size] = float("-inf")
# set the diagonal of the attention value between image patches to 0
full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
# expand the mask to match the attention weights shape
full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
hidden_states_ptb = F.scaled_dot_product_attention(
query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
)
hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
# split the attention outputs.
hidden_states_ptb, encoder_hidden_states_ptb = (
hidden_states_ptb[:, : residual.shape[1]],
hidden_states_ptb[:, residual.shape[1] :],
)
# linear proj
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
# dropout
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
if not attn.context_pre_only:
encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
if input_ndim == 4:
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
if context_input_ndim == 4:
encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
################ concat ###############
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
encoder_hidden_states = torch.cat([encoder_hidden_states_org, encoder_hidden_states_ptb])
return hidden_states, encoder_hidden_states
class FusedJointAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
@@ -1274,6 +1598,103 @@ class AuraFlowAttnProcessor2_0:
return hidden_states
class FusedAuraFlowAttnProcessor2_0:
"""Attention processor used typically in processing Aura Flow with fused projections."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
raise ImportError(
"FusedAuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
*args,
**kwargs,
) -> torch.FloatTensor:
batch_size = hidden_states.shape[0]
# `sample` projections.
qkv = attn.to_qkv(hidden_states)
split_size = qkv.shape[-1] // 3
query, key, value = torch.split(qkv, split_size, dim=-1)
# `context` projections.
if encoder_hidden_states is not None:
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
split_size = encoder_qkv.shape[-1] // 3
(
encoder_hidden_states_query_proj,
encoder_hidden_states_key_proj,
encoder_hidden_states_value_proj,
) = torch.split(encoder_qkv, split_size, dim=-1)
# Reshape.
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)
# Apply QK norm.
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Concatenate the projections.
if encoder_hidden_states is not None:
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Attention.
hidden_states = F.scaled_dot_product_attention(
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# Split the attention outputs.
if encoder_hidden_states is not None:
hidden_states, encoder_hidden_states = (
hidden_states[:, encoder_hidden_states.shape[1] :],
hidden_states[:, : encoder_hidden_states.shape[1]],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if encoder_hidden_states is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if encoder_hidden_states is not None:
return hidden_states, encoder_hidden_states
else:
return hidden_states
# YiYi to-do: refactor rope related functions/classes
def apply_rope(xq, xk, freqs_cis):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""
_always_upcast_modules = ["MaskConditionDecoder"]
@register_to_config
def __init__(
self,
@@ -70,6 +70,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
_always_upcast_modules = ["Decoder"]
@register_to_config
def __init__(
@@ -1,3 +1,18 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import numpy as np
@@ -7,6 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..downsampling import CogVideoXDownsample3D
@@ -16,8 +32,11 @@ from ..upsampling import CogVideoXUpsample3D
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogVideoXSafeConv3d(nn.Conv3d):
"""
r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
"""
@@ -49,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
Args:
in_channels (int): Number of channels in the input tensor.
out_channels (int): Number of output channels.
kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
stride (int, optional): Stride of the convolution. Default is 1.
dilation (int, optional): Dilation rate of the convolution. Default is 1.
pad_mode (str, optional): Padding mode. Default is "constant".
in_channels (`int`): Number of channels in the input tensor.
out_channels (`int`): Number of output channels produced by the convolution.
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
stride (`int`, defaults to `1`): Stride of the convolution.
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
pad_mode (`str`, defaults to `"constant"`): Padding mode.
"""
def __init__(
@@ -98,35 +117,31 @@ class CogVideoXCausalConv3d(nn.Module):
self.conv_cache = None
def fake_cp_pass_from_previous_rank(self, inputs: torch.Tensor) -> torch.Tensor:
dim = self.temporal_dim
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
kernel_size = self.time_kernel_size
if kernel_size == 1:
return inputs
inputs = inputs.transpose(0, dim)
if self.conv_cache is not None:
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
else:
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
inputs = inputs.transpose(0, dim).contiguous()
if kernel_size > 1:
cached_inputs = (
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
)
inputs = torch.cat(cached_inputs + [inputs], dim=2)
return inputs
def forward(self, inputs: torch.Tensor, clear_fake_cp_cache: bool = True):
input_parallel = self.fake_cp_pass_from_previous_rank(inputs)
def _clear_fake_context_parallel_cache(self):
del self.conv_cache
self.conv_cache = None
if not clear_fake_cp_cache:
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
inputs = self.fake_context_parallel_forward(inputs)
self._clear_fake_context_parallel_cache()
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
# hundred megabytes and so let's not do it for now
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output_parallel = self.conv(input_parallel)
output = output_parallel
output = self.conv(inputs)
return output
@@ -142,15 +157,18 @@ class CogVideoXSpatialNorm3D(nn.Module):
The number of channels for input to group normalization layer, and output of the spatial norm layer.
zq_channels (`int`):
The number of channels for the quantized vector as described in the paper.
groups (`int`):
Number of groups to separate the channels into for group normalization.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
groups: int = 32,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True)
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
@@ -175,17 +193,26 @@ class CogVideoXResnetBlock3D(nn.Module):
A 3D ResNet block used in the CogVideoX model.
Args:
in_channels (int): Number of input channels.
out_channels (Optional[int], optional):
Number of output channels. If None, defaults to `in_channels`. Default is None.
dropout (float, optional): Dropout rate. Default is 0.0.
temb_channels (int, optional): Number of time embedding channels. Default is 512.
groups (int, optional): Number of groups for group normalization. Default is 32.
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
non_linearity (str, optional): Activation function to use. Default is "swish".
conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
pad_mode (str, optional): Padding mode. Default is "first".
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
dropout (`float`, defaults to `0.0`):
Dropout rate.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
non_linearity (`str`, defaults to `"swish"`):
Activation function to use.
conv_shortcut (bool, defaults to `False`):
Whether or not to use a convolution shortcut.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
"""
def __init__(
@@ -217,10 +244,12 @@ class CogVideoXResnetBlock3D(nn.Module):
self.norm1 = CogVideoXSpatialNorm3D(
f_channels=in_channels,
zq_channels=spatial_norm_dim,
groups=groups,
)
self.norm2 = CogVideoXSpatialNorm3D(
f_channels=out_channels,
zq_channels=spatial_norm_dim,
groups=groups,
)
self.conv1 = CogVideoXCausalConv3d(
@@ -250,15 +279,16 @@ class CogVideoXResnetBlock3D(nn.Module):
inputs: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
clear_fake_cp_cache: bool = True,
) -> torch.Tensor:
hidden_states = inputs
if zq is not None:
hidden_states = self.norm1(hidden_states, zq)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
hidden_states = self.conv1(hidden_states)
if temb is not None:
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
@@ -270,16 +300,13 @@ class CogVideoXResnetBlock3D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
hidden_states = self.conv2(hidden_states)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
inputs = self.conv_shortcut(inputs, clear_fake_cp_cache=clear_fake_cp_cache)
else:
inputs = self.conv_shortcut(inputs)
inputs = self.conv_shortcut(inputs)
output_tensor = inputs + hidden_states
return output_tensor
hidden_states = hidden_states + inputs
return hidden_states
class CogVideoXDownBlock3D(nn.Module):
@@ -287,18 +314,28 @@ class CogVideoXDownBlock3D(nn.Module):
A downsampling block used in the CogVideoX model.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
compress_time (bool, optional): If True, apply temporal compression. Default is False.
pad_mode (str, optional): Padding mode. Default is "first".
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
add_downsample (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across temporal dimension.
pad_mode (str, defaults to `"first"`):
Padding mode.
"""
_supports_gradient_checkpointing = True
@@ -355,7 +392,6 @@ class CogVideoXDownBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
clear_fake_cp_cache: bool = False,
) -> torch.Tensor:
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
@@ -367,10 +403,10 @@ class CogVideoXDownBlock3D(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
hidden_states = resnet(hidden_states, temb, zq)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
@@ -384,15 +420,24 @@ class CogVideoXMidBlock3D(nn.Module):
A middle block used in the CogVideoX model.
Args:
in_channels (int): Number of input channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
pad_mode (str, optional): Padding mode. Default is "first".
in_channels (`int`):
Number of input channels.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
dropout (`float`, defaults to `0.0`):
Dropout rate.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, *optional*):
The dimension to use for spatial norm if it is to be used instead of group norm.
pad_mode (str, defaults to `"first"`):
Padding mode.
"""
_supports_gradient_checkpointing = True
@@ -435,7 +480,6 @@ class CogVideoXMidBlock3D(nn.Module):
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
clear_fake_cp_cache: bool = False,
) -> torch.Tensor:
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
@@ -447,10 +491,10 @@ class CogVideoXMidBlock3D(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
hidden_states = resnet(hidden_states, temb, zq)
return hidden_states
@@ -460,19 +504,30 @@ class CogVideoXUpBlock3D(nn.Module):
An upsampling block used in the CogVideoX model.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
temb_channels (int): Number of time embedding channels.
dropout (float, optional): Dropout rate. Default is 0.0.
num_layers (int, optional): Number of layers in the block. Default is 1.
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
compress_time (bool, optional): If True, apply temporal compression. Default is False.
pad_mode (str, optional): Padding mode. Default is "first".
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
temb_channels (`int`, defaults to `512`):
Number of time embedding channels.
dropout (`float`, defaults to `0.0`):
Dropout rate.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
resnet_groups (`int`, defaults to `32`):
Number of groups to separate the channels into for group normalization.
spatial_norm_dim (`int`, defaults to `16`):
The dimension to use for spatial norm if it is to be used instead of group norm.
add_upsample (`bool`, defaults to `True`):
Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
compress_time (`bool`, defaults to `False`):
Whether or not to downsample across temporal dimension.
pad_mode (str, defaults to `"first"`):
Padding mode.
"""
def __init__(
@@ -522,12 +577,13 @@ class CogVideoXUpBlock3D(nn.Module):
]
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
zq: Optional[torch.Tensor] = None,
clear_fake_cp_cache: bool = False,
) -> torch.Tensor:
r"""Forward method of the `CogVideoXUpBlock3D` class."""
for resnet in self.resnets:
@@ -540,10 +596,10 @@ class CogVideoXUpBlock3D(nn.Module):
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, zq, clear_fake_cp_cache
create_custom_forward(resnet), hidden_states, temb, zq
)
else:
hidden_states = resnet(hidden_states, temb, zq, clear_fake_cp_cache)
hidden_states = resnet(hidden_states, temb, zq)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -566,14 +622,12 @@ class CogVideoXEncoder3D(nn.Module):
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
"""
_supports_gradient_checkpointing = True
@@ -651,11 +705,9 @@ class CogVideoXEncoder3D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True
) -> torch.Tensor:
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""The forward method of the `CogVideoXEncoder3D` class."""
hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache)
hidden_states = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
@@ -668,25 +720,25 @@ class CogVideoXEncoder3D(nn.Module):
# 1. Down
for down_block in self.down_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), hidden_states, temb, None, clear_fake_cp_cache
create_custom_forward(down_block), hidden_states, temb, None
)
# 2. Mid
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb, None, clear_fake_cp_cache
create_custom_forward(self.mid_block), hidden_states, temb, None
)
else:
# 1. Down
for down_block in self.down_blocks:
hidden_states = down_block(hidden_states, temb, None, clear_fake_cp_cache)
hidden_states = down_block(hidden_states, temb, None)
# 2. Mid
hidden_states = self.mid_block(hidden_states, temb, None, clear_fake_cp_cache)
hidden_states = self.mid_block(hidden_states, temb, None)
# 3. Post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
hidden_states = self.conv_out(hidden_states)
return hidden_states
@@ -704,14 +756,12 @@ class CogVideoXDecoder3D(nn.Module):
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
norm_type (`str`, *optional*, defaults to `"group"`):
The normalization type to use. Can be either `"group"` or `"spatial"`.
"""
_supports_gradient_checkpointing = True
@@ -788,7 +838,7 @@ class CogVideoXDecoder3D(nn.Module):
self.up_blocks.append(up_block)
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels)
self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
self.conv_act = nn.SiLU()
self.conv_out = CogVideoXCausalConv3d(
reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
@@ -796,11 +846,9 @@ class CogVideoXDecoder3D(nn.Module):
self.gradient_checkpointing = False
def forward(
self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None, clear_fake_cp_cache: bool = True
) -> torch.Tensor:
def forward(self, sample: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
r"""The forward method of the `CogVideoXDecoder3D` class."""
hidden_states = self.conv_in(sample, clear_fake_cp_cache=clear_fake_cp_cache)
hidden_states = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
@@ -812,32 +860,33 @@ class CogVideoXDecoder3D(nn.Module):
# 1. Mid
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), hidden_states, temb, sample, clear_fake_cp_cache
create_custom_forward(self.mid_block), hidden_states, temb, sample
)
# 2. Up
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block), hidden_states, temb, sample, clear_fake_cp_cache
create_custom_forward(up_block), hidden_states, temb, sample
)
else:
# 1. Mid
hidden_states = self.mid_block(hidden_states, temb, sample, clear_fake_cp_cache)
hidden_states = self.mid_block(hidden_states, temb, sample)
# 2. Up
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states, temb, sample, clear_fake_cp_cache)
hidden_states = up_block(hidden_states, temb, sample)
# 3. Post-process
hidden_states = self.norm_out(hidden_states, sample)
hidden_states = self.conv_act(hidden_states)
hidden_states = self.conv_out(hidden_states, clear_fake_cp_cache=clear_fake_cp_cache)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encodfing images into latents and decoding latent representations into images.
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
@@ -864,9 +913,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
mid_block_add_attention (`bool`, *optional*, default to `True`):
If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
mid_block will only have resnet blocks
"""
_supports_gradient_checkpointing = True
@@ -896,7 +942,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
norm_eps: float = 1e-6,
norm_num_groups: int = 32,
temporal_compression_ratio: float = 4,
sample_size: int = 256,
sample_height: int = 480,
sample_width: int = 720,
scaling_factor: float = 1.15258426,
shift_factor: Optional[float] = None,
latents_mean: Optional[Tuple[float]] = None,
@@ -904,7 +951,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
force_upcast: float = True,
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
mid_block_add_attention: bool = True,
):
super().__init__()
@@ -936,22 +982,108 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.use_slicing = False
self.use_tiling = False
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
# recommended because the temporal parts of the VAE, here, are tricky to understand.
# If you decode X latent frames together, the number of output frames is:
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
#
# Example with num_latent_frames_batch_size = 2:
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 6 * 8 = 48 frames
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
# => 1 * 9 + 5 * 8 = 49 frames
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
# number of temporal frames.
self.num_latent_frames_batch_size = 2
# We make the minimum height and width of sample for tiling half that of the generally supported
self.tile_sample_min_height = sample_height // 2
self.tile_sample_min_width = sample_width // 2
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
# and so the tiling implementation has only been tested on those specific resolutions.
self.tile_overlap_factor_height = 1 / 6
self.tile_overlap_factor_width = 1 / 5
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
module.gradient_checkpointing = value
def _clear_fake_context_parallel_cache(self):
for name, module in self.named_modules():
if isinstance(module, CogVideoXCausalConv3d):
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
module._clear_fake_context_parallel_cache()
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
tile_overlap_factor_width (`int`, *optional*):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True, fake_cp: bool = False
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
@@ -960,14 +1092,12 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
fake_cp (`bool`, *optional*, defaults to `True`):
If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work).
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self.encoder(x, clear_fake_cp_cache=not fake_cp)
h = self.encoder(x)
if self.quant_conv is not None:
h = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(h)
@@ -975,10 +1105,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = z.shape
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
return self.tiled_decode(z, return_dict=return_dict)
frame_batch_size = self.num_latent_frames_batch_size
dec = []
for i in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
z_intermediate = z[:, :, start_frame:end_frame]
if self.post_quant_conv is not None:
z_intermediate = self.post_quant_conv(z_intermediate)
z_intermediate = self.decoder(z_intermediate)
dec.append(z_intermediate)
self._clear_fake_context_parallel_cache()
dec = torch.cat(dec, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, fake_cp: bool = False
) -> Union[DecoderOutput, torch.FloatTensor]:
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
@@ -986,20 +1140,116 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
fake_cp (`bool`, *optional*, defaults to `True`):
If True, the fake context parallel will be used to reduce GPU memory consumption (Only 1 GPU work).
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.post_quant_conv is not None:
z = self.post_quant_conv(z)
dec = self.decoder(z, clear_fake_cp_cache=not fake_cp)
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
# Rough memory assessment:
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
# - Assume fp16 (2 bytes per value).
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
#
# Memory assessment when using tiling:
# - Assume everything as above but now HxW is 240x360 by tiling in half
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
batch_size, num_channels, num_frames, height, width = z.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
time = []
for k in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile = self.decoder(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
@@ -192,6 +192,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
"""
_supports_gradient_checkpointing = True
_always_upcast_modules = ["TemporalDecoder"]
@register_to_config
def __init__(

Some files were not shown because too many files have changed in this diff Show More