Compare commits

..

67 Commits

Author SHA1 Message Date
DN6 d9915a7d65 update 2025-03-12 11:44:40 +05:30
DN6 b7a795dbeb update 2025-03-12 11:40:40 +05:30
DN6 438905d63e update 2025-03-12 11:37:27 +05:30
DN6 904f24de5a update 2025-03-12 11:35:18 +05:30
DN6 e123bbcbc4 memmap 2025-03-12 11:23:14 +05:30
DN6 b3fa8c695d remove cpu param dict 2025-03-12 09:02:04 +05:30
DN6 720be2bac5 update 2025-03-12 08:49:45 +05:30
DN6 e74b782aac update 2025-03-12 08:45:09 +05:30
DN6 d6392b4b49 update 2025-03-12 08:18:19 +05:30
DN6 1475026960 sliding-window 2025-03-11 13:56:39 +05:30
DN6 878eb4ce35 update 2025-03-11 13:21:09 +05:30
Dhruv Nair 9add071592 [Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6 (#11018)
* update

* update

* update

* update

* update

* update

* update

* update

* update
2025-03-11 10:52:01 +05:30
Tolga Cangöz b88fef4785 [Research Project] Add AnyText: Multilingual Visual Text Generation And Editing (#8998)
* Add initial template

* Second template

* feat: Add TextEmbeddingModule to AnyTextPipeline

* feat: Add AuxiliaryLatentModule template to AnyTextPipeline

* Add bert tokenizer from the anytext repo for now

* feat: Update AnyTextPipeline's modify_prompt method

This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe.

* Fill in the `forward` pass of `AuxiliaryLatentModule`

* `make style && make quality`

* `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library`

* Update error handling to raise and logging

* Add `create_glyph_lines` function into `TextEmbeddingModule`

* make style

* Up

* Up

* Up

* Up

* Remove several comments

* refactor: Remove ControlNetConditioningEmbedding and update code accordingly

* Up

* Up

* up

* refactor: Update AnyTextPipeline to include new optional parameters

* up

* feat: Add OCR model and its components

* chore: Update `TextEmbeddingModule` to include OCR model components and dependencies

* chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task

* `make style`

* refactor: Update `AnyTextPipeline`'s docstring

* Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once

* simplify

* `make style`

* Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function

* Simplify for now

* `make style`

* Up

* feat: Add scripts to convert AnyText controlnet to diffusers

* `make style`

* Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule`

* make style

* Up

* Simplify

* Up

* feat: Add safetensors module for loading model file

* Fix device issues

* Up

* Up

* refactor: Simplify

* refactor: Simplify code for loading models and handling data types

* `make style`

* refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule

* refactor: Update dtype in embedding_manager.py to match proj.weight

* Up

* Add attribution and adaptation information to pipeline_anytext.py

* Update usage example

* Will refactor `controlnet_cond_embedding` initialization

* Add `AnyTextControlNetConditioningEmbedding` template

* Refactor organization

* style

* style

* Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding`

* Follow one-file policy

* style

* [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel

* [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py

* [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py

* Refactor AnyTextControlNet to use configurable conditioning embedding channels

* Complete control net conditioning embedding in AnyTextControlNetModel

* up

* [FIX] Ensure embeddings use correct device in AnyTextControlNetModel

* up

* up

* style

* [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline

* [UPDATE] Update example code in anytext.py to use correct font file and improve clarity

* down

* [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing

* update pillow

* [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity

* [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file

* [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency

* 🆙

* style

* [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py

* style

* Update examples/research_projects/anytext/README.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Remove commented-out image preparation code in AnyTextPipeline

* Remove unnecessary blank line in README.md
2025-03-11 01:49:37 +05:30
Sayak Paul e7e6d85282 [Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)
* memory usage tests

* fixes

* gguf
2025-03-10 21:42:24 +05:30
Aryan 8eefed65bd [LoRA] CogView4 (#10981)
* update

* make fix-copies

* update
2025-03-10 20:24:05 +05:30
Sayak Paul 26149c0ecd [LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)
* updates

* updates

* updates

* updates

* notebooks revert

* fix-copies.

* seeing

* fix

* revert

* fixes

* fixes

* fixes

* remove print

* fix

* conflicts ii.

* updates

* fixes

* better filtering of prefix.

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-03-10 09:28:32 +05:30
Ishan Modi 0703ce8800 [Single File] Add single file loading for SANA Transformer (#10947)
* added support for from_single_file

* added diffusers mapping script

* added testcase

* bug fix

* updated tests

* corrected code quality

* corrected code quality

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-03-10 08:38:30 +05:30
Dhruv Nair f5edaa7894 [Quantization] Add Quanto backend (#10756)
* update

* updaet

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/quanto.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/quanto/utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-10 08:33:05 +05:30
Dhruv Nair 9a1810f0de Fix for fetching variants only (#10646)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
2025-03-10 07:45:44 +05:30
Sayak Paul 1fddee211e [LoRA] Improve copied from comments in the LoRA loader classes (#10995)
* more sanity of mind with copied from ...

* better

* better
2025-03-08 19:59:21 +05:30
Kinam Kim b38450d5d2 Add STG to community pipelines (#10960)
* Support STG for video pipelines

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update pipeline_stg_cogvideox.py

* Update pipeline_stg_hunyuan_video.py

* Update pipeline_stg_ltx.py

* Update pipeline_stg_ltx_image2video.py

* Update pipeline_stg_mochi.py

* Update pipeline_stg_hunyuan_video.py

* Update pipeline_stg_ltx.py

* Update pipeline_stg_ltx_image2video.py

* Update pipeline_stg_mochi.py

* update

* remove rescaling

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-03-08 00:28:24 +05:30
Dhruv Nair 1357931d74 [Single File] Add single file support for Wan T2V/I2V (#10991)
* update

* update

* update

* update

* update

* update

* update
2025-03-07 22:13:25 +05:30
Sayak Paul a2d3d6af44 [LoRA] remove full key prefix from peft. (#11004)
remove full key prefix from peft.
2025-03-07 21:51:59 +05:30
hlky 363d1ab7e2 Wan VAE move scaling to pipeline (#10998) 2025-03-07 10:42:17 +00:00
C 6a0137eb3b Fix Graph Breaks When Compiling CogView4 (#10959)
* Fix Graph Breaks When Compiling CogView4

Eliminate this:

```
t]V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] Recompiling function forward in /home/zeyi/repos/diffusers/src/diffusers/models/transformers/transformer_cogview4.py:374
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     triggered by the following guard failure(s):
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/3: ___check_obj_id(L['self'].rope.freqs_h, 139976127328032)    
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/2: ___check_obj_id(L['self'].rope.freqs_h, 139976107780960)    
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/1: ___check_obj_id(L['self'].rope.freqs_h, 140022511848960)    
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/0: ___check_obj_id(L['self'].rope.freqs_h, 140024081342416)   
```

* Update transformer_cogview4.py

* fix cogview4 rotary pos embed

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-03-06 22:57:17 -10:00
Aryan 2e5203be04 Hunyuan I2V (#10983)
* update

* update

* update

* add tests

* update

* add model tests

* update docs

* update

* update example

* fix defaults

* update
2025-03-07 12:52:48 +05:30
yupeng1111 d55f41102a fix wan i2v pipeline bugs (#10975)
* fix wan i2v pipeline bugs

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-03-06 18:57:41 -10:00
LittleNyima 748cb0fab6 Add CogVideoX DDIM Inversion to Community Pipelines (#10956)
* add cogvideox ddim inversion script

* implement as a pipeline, and add documentation

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-03-06 10:46:38 -10:00
Dhruv Nair 790a909b54 [Single File] Add user agent to SF download requests. (#10979)
update
2025-03-06 10:45:20 -10:00
CyberVy 54ab475391 Fix Flux Controlnet Pipeline _callback_tensor_inputs Missing Some Elements (#10974)
* Update pipeline_flux_controlnet.py

* Update pipeline_flux_controlnet_image_to_image.py

* Update pipeline_flux_controlnet_inpainting.py

* Update pipeline_flux_controlnet_inpainting.py

* Update pipeline_flux_controlnet_inpainting.py
2025-03-06 14:26:20 -03:00
dependabot[bot] f103993094 Bump jinja2 from 3.1.5 to 3.1.6 in /examples/research_projects/realfill (#10984)
Bumps [jinja2](https://github.com/pallets/jinja) from 3.1.5 to 3.1.6.
- [Release notes](https://github.com/pallets/jinja/releases)
- [Changelog](https://github.com/pallets/jinja/blob/main/CHANGES.rst)
- [Commits](https://github.com/pallets/jinja/compare/3.1.5...3.1.6)

---
updated-dependencies:
- dependency-name: jinja2
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-03-06 11:59:51 +00:00
Sayak Paul 1be0202502 [CI] remove synchornized. (#10980)
removed synchornized.
2025-03-06 17:03:19 +05:30
Pierre Chapuis ea81a4228d fix default values of Flux guidance_scale in docstrings (#10982) 2025-03-06 16:37:45 +05:30
hlky b15027636a Fix loading OneTrainer Flux LoRA (#10978)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-06 13:53:36 +05:30
Sayak Paul 6e2a93de70 [tests] fix tests for save load components (#10977)
fix tests
2025-03-06 12:30:37 +05:30
Jun Yeop Na 37b8edfb86 [train_dreambooth_lora.py] Fix the LR Schedulers when num_train_epochs is passed in a distributed training env (#10973)
* updated train_dreambooth_lora to fix the LR schedulers for `num_train_epochs` in distributed training env

* fixed formatting

* remove trailing newlines

* fixed style error
2025-03-06 10:06:24 +05:30
Célina fbf6b856cc use style bot GH Action from huggingface_hub (#10970)
use style bot GH action from hfh

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-05 23:39:50 +05:30
Linoy Tsaban e031caf4ea [flux lora training] fix t5 training bug (#10845)
* fix t5 training bug

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-03-05 13:47:01 +02:00
hlky 08f74a8b92 Add VAE Decode endpoint slow test (#10946) 2025-03-05 11:28:06 +00:00
YiYi Xu 24c062aaa1 update check_input for cogview4 (#10966)
fix
2025-03-04 12:12:54 -10:00
Yuxuan Zhang a74f02fb40 [Docs] CogView4 comment fix (#10957)
* Update pipeline_cogview4.py

* Use GLM instead of T5 in doc
2025-03-04 11:25:43 -10:00
Eliseu Silva 66bf7ea5be feat: add Mixture-of-Diffusers ControlNet Tile upscaler Pipeline for SDXL (#10951)
* feat: add Mixture-of-Diffusers ControlNet Tile upscaler Pipeline for SDXL

* make style make quality
2025-03-04 17:17:36 -03:00
Alexey Zolotenkov b8215b1c06 Fix incorrect seed initialization when args.seed is 0 (#10964)
* Fix seed initialization to handle args.seed = 0 correctly

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-03-04 10:09:52 -10:00
Aryan 3ee899fa0c [LoRA] Support Wan (#10943)
* update

* refactor image-to-video pipeline

* update

* fix copied from

* use FP32LayerNorm
2025-03-05 01:27:34 +05:30
CyberVy dcd77ce222 Fix the missing parentheses when calling is_torchao_available in quantization_config.py. (#10961)
Update quantization_config.py
2025-03-04 09:52:41 -03:00
a120092009 11d8e3ce2c [Quantization] support pass MappingType for TorchAoConfig (#10927)
* [Quantization] support pass MappingType for TorchAoConfig

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-03-04 16:40:50 +05:30
Sayak Paul 97fda1b75c [LoRA] feat: support non-diffusers lumina2 LoRAs. (#10909)
* feat: support non-diffusers lumina2 LoRAs.

* revert ipynb changes (but I don't know why this is required ☹️)

* empty

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-03-04 14:40:55 +05:30
Sayak Paul cc22058324 Update evaluation.md (#10938)
* Update evaluation.md

* Update docs/source/en/conceptual/evaluation.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-03-04 13:58:16 +05:30
Fanli Lin 7855ac597e [tests] make tests device-agnostic (part 4) (#10508)
* initial comit

* fix empty cache

* fix one more

* fix style

* update device functions

* update

* update

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/controlnet/test_controlnet.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/controlnet/test_controlnet.py

Co-authored-by: hlky <hlky@hlky.ac>

* with gc.collect

* update

* make style

* check_torch_dependencies

* add mps empty cache

* add changes

* bug fix

* enable on xpu

* update more cases

* revert

* revert back

* Update test_stable_diffusion_xl.py

* Update tests/pipelines/stable_diffusion/test_stable_diffusion.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/stable_diffusion/test_stable_diffusion.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Co-authored-by: hlky <hlky@hlky.ac>

* Apply suggestions from code review

Co-authored-by: hlky <hlky@hlky.ac>

* add test marker

---------

Co-authored-by: hlky <hlky@hlky.ac>
2025-03-04 08:26:06 +00:00
CyberVy 30cef6bff3 Improve load_ip_adapter RAM Usage (#10948)
* Update ip_adapter.py

* Update ip_adapter.py

* Update ip_adapter.py

* Update ip_adapter.py

* Update ip_adapter.py

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-03-04 07:21:23 +00:00
Ahmed Belgacem 8f15be169f Fix redundant prev_output_channel assignment in UNet2DModel (#10945) 2025-03-03 11:43:15 -10:00
Yuxuan Zhang f92e599c70 Update pipeline_cogview4.py (#10944) 2025-03-03 09:42:01 -10:00
Parag Ekbote 982f9b38d6 Add Example of IPAdapterScaleCutoffCallback to Docs (#10934)
* Add example of Ip-Adapter-Callback.

* Add image links from HF Hub.
2025-03-03 08:32:45 -08:00
fancydaddy c9a219b323 add from_single_file to animatediff (#10924)
* Update pipeline_animatediff.py

* Update pipeline_animatediff_controlnet.py

* Update pipeline_animatediff_sparsectrl.py

* Update pipeline_animatediff_video2video.py

* Update pipeline_animatediff_video2video_controlnet.py

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-03-03 19:11:54 +05:30
Teriks 9e910c4633 Fix SD2.X clip single file load projection_dim (#10770)
* Fix SD2.X clip single file load projection_dim

Infer projection_dim from the checkpoint before loading
from pretrained, override any incorrect hub config.

Hub configuration for SD2.X specifies projection_dim=512
which is incorrect for SD2.X checkpoints loaded from civitai
and similar.

Exception was previously thrown upon attempting to
load_model_dict_into_meta for SD2.X single file checkpoints.

Such LDM models usually require projection_dim=1024

* convert_open_clip_checkpoint use hidden_size for text_proj_dim

* convert_open_clip_checkpoint, revert checkpoint[text_proj_key].shape[1] -> [0]

values are identical

---------

Co-authored-by: Teriks <Teriks@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-03-03 19:00:39 +05:30
Bubbliiiing 5e3b7d2d8a Add EasyAnimateV5.1 text-to-video, image-to-video, control-to-video generation model (#10626)
* Update EasyAnimate V5.1

* Add docs && add tests && Fix comments problems in transformer3d and vae

* delete comments and remove useless import

* delete process

* Update EXAMPLE_DOC_STRING

* rename transformer file

* make fix-copies

* make style

* refactor pt. 1

* update toctree.yml

* add model tests

* Update layer_norm for norm_added_q and norm_added_k in Attention

* Fix processor problem

* refactor vae

* Fix problem in comments

* refactor tiling; remove einops dependency

* fix docs path

* make fix-copies

* Update src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py

* update _toctree.yml

* fix test

* update

* update

* update

* make fix-copies

* fix tests

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Aryan <contact.aryanvs@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-03-03 18:37:19 +05:30
Sayak Paul 7513162b8b [Tests] Remove more encode prompts tests (#10942)
* fix-copies went uncaught it seems.

* remove more unneeded encode_prompt() tests

* Revert "fix-copies went uncaught it seems."

This reverts commit eefb302791.

* empty
2025-03-03 16:55:01 +05:30
Sayak Paul 4aaa0d21ba [chore] fix-copies to flux pipelines (#10941)
fix-copies went uncaught it seems.
2025-03-03 11:21:57 +05:30
hlky 54043c3e2e Update VAE Decode endpoints (#10939) 2025-03-02 18:29:53 +00:00
hlky fc4229a0c3 Add remote_decode to remote_utils (#10898)
* Add `remote_decode` to `remote_utils`

* test dependency

* test dependency

* dependency

* dependency

* dependency

* docstrings

* changes

* make style

* apply

* revert, add new options

* Apply style fixes

* deprecate base64, headers not needed

* address comments

* add license header

* init test_remote_decode

* more

* more test

* more test

* skeleton for xl, flux

* more test

* flux test

* flux packed

* no scaling

* -save

* hunyuanvideo test

* Apply style fixes

* init docs

* Update src/diffusers/utils/remote_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* comments

* Apply style fixes

* comments

* hybrid_inference/vae_decode

* fix

* tip?

* tip

* api reference autodoc

* install tip

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-03-02 17:10:01 +00:00
hlky 694f9658c1 Support IPAdapter for more Flux pipelines (#10708)
* Support IPAdapter for more Flux pipelines

* -copied from

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-02 15:04:12 +00:00
YiYi Xu 2d8a41cae8 [Alibaba Wan Team] continue on #10921 Wan2.1 (#10922)
* Add wanx pipeline, model and example

* wanx_merged_v1

* change WanX into Wan

* fix i2v fp32 oom error

Link: https://code.alibaba-inc.com/open_wanx2/diffusers/codereview/20607813

* support t2v load fp32 ckpt

* add example

* final merge v1

* Update autoencoder_kl_wan.py

* up

* update middle, test up_block

* up up

* one less nn.sequential

* up more

* up

* more

* [refactor] [wip] Wan transformer/pipeline (#10926)

* update

* update

* refactor rope

* refactor pipeline

* make fix-copies

* add transformer test

* update

* update

* make style

* update tests

* tests

* conversion script

* conversion script

* update

* docs

* remove unused code

* fix _toctree.yml

* update dtype

* fix test

* fix tests: scale

* up

* more

* Apply suggestions from code review

* Apply suggestions from code review

* style

* Update scripts/convert_wan_to_diffusers.py

* update docs

* fix

---------

Co-authored-by: Yitong Huang <huangyitong.hyt@alibaba-inc.com>
Co-authored-by: 亚森 <wangjiayu.wjy@alibaba-inc.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2025-03-02 17:24:26 +05:30
Dhruv Nair 7007febae5 [CI] Update Stylebot Permissions (#10931)
update
2025-03-01 09:43:05 +05:30
Sayak Paul d230ecc570 [style bot] improve security for the stylebot. (#10908)
* improve security for the stylebot.

* 
2025-02-28 22:01:31 +05:30
hlky 37a5f1b3b6 Experimental per control type scale for ControlNet Union (#10723)
* ControlNet Union scale

* fix

* universal interface

* from_multi

* from_multi
2025-02-27 10:23:38 +00:00
Dhruv Nair 501d9de701 [CI] Fix for failing IP Adapter test in Fast GPU PR tests (#10915)
* update

* update

* update

* update
2025-02-27 14:22:28 +05:30
Dhruv Nair e5c43b8af7 [CI] Fix Fast GPU tests on PR (#10912)
* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-02-27 14:21:50 +05:30
241 changed files with 28455 additions and 2103 deletions
+2
View File
@@ -418,6 +418,8 @@ jobs:
test_location: "gguf"
- backend: "torchao"
test_location: "torchao"
- backend: "optimum_quanto"
test_location: "quanto"
runs-on:
group: aws-g6e-xlarge-plus
container:
+33 -109
View File
@@ -9,119 +9,43 @@ permissions:
pull-requests: write
jobs:
run-style-bot:
if: >
contains(github.event.comment.body, '@bot /style') &&
github.event.issue.pull_request != null
runs-on: ubuntu-latest
style:
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
with:
python_quality_dependencies: "[quality]"
pre_commit_script_name: "Download and Compare files from the main branch"
pre_commit_script: |
echo "Downloading the files from the main branch"
steps:
- name: Extract PR details
id: pr_info
uses: actions/github-script@v6
with:
script: |
const prNumber = context.payload.issue.number;
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber
});
// We capture both the branch ref and the "full_name" of the head repo
// so that we can check out the correct repository & branch (including forks).
core.setOutput("prNumber", prNumber);
core.setOutput("headRef", pr.head.ref);
core.setOutput("headRepoFullName", pr.head.repo.full_name);
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
- name: Check out PR branch
uses: actions/checkout@v3
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
with:
# Instead of checking out the base repo, use the contributor's repo name
repository: ${{ env.HEADREPOFULLNAME }}
ref: ${{ env.HEADREF }}
# You may need fetch-depth: 0 for being able to push
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Debug
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
run: |
echo "PR number: $PRNUMBER"
echo "Head Ref: $HEADREF"
echo "Head Repo Full Name: $HEADREPOFULLNAME"
echo "Compare the files and raise error if needed"
- name: Set up Python
uses: actions/setup-python@v4
diff_failed=0
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
diff_failed=1
fi
- name: Install dependencies
run: |
pip install .[quality]
if ! diff -q main_setup.py setup.py; then
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
- name: Download Makefile from main branch
run: |
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
- name: Compare Makefiles
run: |
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
exit 1
fi
echo "No changes in Makefile. Proceeding..."
rm -rf main_Makefile
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
- name: Run make style and make quality
run: |
make style && make quality
if [ $diff_failed -eq 1 ]; then
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
exit 1
fi
- name: Commit and push changes
id: commit_and_push
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
echo "HEADREPOFULLNAME: $HEADREPOFULLNAME, HEADREF: $HEADREF"
# Configure git with the Actions bot user
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
# Make sure your 'origin' remote is set to the contributor's fork
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/$HEADREPOFULLNAME.git"
# If there are changes after running style/quality, commit them
if [ -n "$(git status --porcelain)" ]; then
git add .
git commit -m "Apply style fixes"
# Push to the original contributor's forked branch
git push origin HEAD:$HEADREF
echo "changes_pushed=true" >> $GITHUB_OUTPUT
else
echo "No changes to commit."
echo "changes_pushed=false" >> $GITHUB_OUTPUT
fi
- name: Comment on PR with workflow run link
if: steps.commit_and_push.outputs.changes_pushed == 'true'
uses: actions/github-script@v6
with:
script: |
const prNumber = parseInt(process.env.prNumber, 10);
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
});
env:
prNumber: ${{ steps.pr_info.outputs.prNumber }}
echo "No changes in the files. Proceeding..."
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
style_command: "make style && make quality"
secrets:
bot_token: ${{ secrets.GITHUB_TOKEN }}
-1
View File
@@ -3,7 +3,6 @@ name: Fast tests for PRs
on:
pull_request:
branches: [main]
types: [synchronize]
paths:
- "src/diffusers/**.py"
- "benchmarks/**.py"
+12 -5
View File
@@ -106,11 +106,18 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
fi
- name: Failure short reports
if: ${{ failure() }}
+22
View File
@@ -76,6 +76,14 @@
- local: advanced_inference/outpaint
title: Outpainting
title: Advanced inference
- sections:
- local: hybrid_inference/overview
title: Overview
- local: hybrid_inference/vae_decode
title: VAE Decode
- local: hybrid_inference/api_reference
title: API Reference
title: Hybrid Inference
- sections:
- local: using-diffusers/cogvideox
title: CogVideoX
@@ -165,6 +173,8 @@
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
title: quanto
title: Quantization Methods
- sections:
- local: optimization/fp16
@@ -282,6 +292,8 @@
title: CogView4Transformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
title: EasyAnimateTransformer3DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -314,6 +326,8 @@
title: Transformer2DModel
- local: api/models/transformer_temporal
title: TransformerTemporalModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
@@ -342,8 +356,12 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
title: AutoencoderKLMagvit
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/autoencoder_kl_wan
title: AutoencoderKLWan
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
@@ -418,6 +436,8 @@
title: DiffEdit
- local: api/pipelines/dit
title: DiT
- local: api/pipelines/easyanimate
title: EasyAnimate
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/control_flux_inpaint
@@ -534,6 +554,8 @@
title: UniDiffuser
- local: api/pipelines/value_guided_sampling
title: Value-guided sampling
- local: api/pipelines/wan
title: Wan
- local: api/pipelines/wuerstchen
title: Wuerstchen
title: Pipelines
@@ -0,0 +1,32 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLWan
The 3D variational autoencoder (VAE) model with KL loss used in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLWan
vae = AutoencoderKLWan.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
```
## AutoencoderKLWan
[[autodoc]] AutoencoderKLWan
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,37 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLMagvit
The 3D variational autoencoder (VAE) model with KL loss used in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLMagvit
vae = AutoencoderKLMagvit.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="vae", torch_dtype=torch.float16).to("cuda")
```
## AutoencoderKLMagvit
[[autodoc]] AutoencoderKLMagvit
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,30 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# EasyAnimateTransformer3DModel
A Diffusion Transformer model for 3D data from [EasyAnimate](https://github.com/aigc-apps/EasyAnimate) was introduced by Alibaba PAI.
The model can be loaded with the following code snippet.
```python
from diffusers import EasyAnimateTransformer3DModel
transformer = EasyAnimateTransformer3DModel.from_pretrained("alibaba-pai/EasyAnimateV5.1-12b-zh", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
## EasyAnimateTransformer3DModel
[[autodoc]] EasyAnimateTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# WanTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
The model can be loaded with the following code snippet.
```python
from diffusers import WanTransformer3DModel
transformer = WanTransformer3DModel.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## WanTransformer3DModel
[[autodoc]] WanTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -0,0 +1,88 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# EasyAnimate
[EasyAnimate](https://github.com/aigc-apps/EasyAnimate) by Alibaba PAI.
The description from it's GitHub page:
*EasyAnimate is a pipeline based on the transformer architecture, designed for generating AI images and videos, and for training baseline models and Lora models for Diffusion Transformer. We support direct prediction from pre-trained EasyAnimate models, allowing for the generation of videos with various resolutions, approximately 6 seconds in length, at 8fps (EasyAnimateV5.1, 1 to 49 frames). Additionally, users can train their own baseline and Lora models for specific style transformations.*
This pipeline was contributed by [bubbliiiing](https://github.com/bubbliiiing). The original codebase can be found [here](https://huggingface.co/alibaba-pai). The original weights can be found under [hf.co/alibaba-pai](https://huggingface.co/alibaba-pai).
There are two official EasyAnimate checkpoints for text-to-video and video-to-video.
| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh) | torch.float16 |
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
There is one official EasyAnimate checkpoints available for image-to-video and video-to-video.
| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-InP`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-InP) | torch.float16 |
There are two official EasyAnimate checkpoints available for control-to-video.
| checkpoints | recommended inference dtype |
|:---:|:---:|
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control) | torch.float16 |
| [`alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera`](https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control-Camera) | torch.float16 |
For the EasyAnimateV5.1 series:
- Text-to-video (T2V) and Image-to-video (I2V) works for multiple resolutions. The width and height can vary from 256 to 1024.
- Both T2V and I2V models support generation with 1~49 frames and work best at this value. Exporting videos at 8 FPS is recommended.
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`EasyAnimatePipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, EasyAnimateTransformer3DModel, EasyAnimatePipeline
from diffusers.utils import export_to_video
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = EasyAnimateTransformer3DModel.from_pretrained(
"alibaba-pai/EasyAnimateV5.1-12b-zh",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.float16,
)
pipeline = EasyAnimatePipeline.from_pretrained(
"alibaba-pai/EasyAnimateV5.1-12b-zh",
transformer=transformer_8bit,
torch_dtype=torch.float16,
device_map="balanced",
)
prompt = "A cat walks on the grass, realistic style."
negative_prompt = "bad detailed"
video = pipeline(prompt=prompt, negative_prompt=negative_prompt, num_frames=49, num_inference_steps=30).frames[0]
export_to_video(video, "cat.mp4", fps=8)
```
## EasyAnimatePipeline
[[autodoc]] EasyAnimatePipeline
- all
- __call__
## EasyAnimatePipelineOutput
[[autodoc]] pipelines.easyanimate.pipeline_output.EasyAnimatePipelineOutput
@@ -49,7 +49,8 @@ The following models are available for the image-to-video pipeline:
| Model name | Description |
|:---|:---|
| [`https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
## Quantization
+78
View File
@@ -0,0 +1,78 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Wan
[Wan 2.1](https://github.com/Wan-Video/Wan2.1) by the Alibaba Wan Team.
<!-- TODO(aryan): update abstract once paper is out -->
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
Recommendations for inference:
- VAE in `torch.float32` for better decoding quality.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
### Using a custom scheduler
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
```python
from diffusers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler, WanPipeline
scheduler_a = FlowMatchEulerDiscreteScheduler(shift=5.0)
scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=4.0)
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler=<CUSTOM_SCHEDULER_HERE>)
# or,
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
```
### Using single file loading with Wan
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
method.
```python
import torch
from diffusers import WanPipeline, WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
## WanPipeline
[[autodoc]] WanPipeline
- all
- __call__
## WanImageToVideoPipeline
[[autodoc]] WanImageToVideoPipeline
- all
- __call__
## WanPipelineOutput
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
+5
View File
@@ -31,6 +31,11 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
## GGUFQuantizationConfig
[[autodoc]] GGUFQuantizationConfig
## QuantoConfig
[[autodoc]] QuantoConfig
## TorchAoConfig
[[autodoc]] TorchAoConfig
+5
View File
@@ -16,6 +16,11 @@ specific language governing permissions and limitations under the License.
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
> [!TIP]
> This document has now grown outdated given the emergence of existing evaluation frameworks for diffusion models for image generation. Please check
> out works like [HEIM](https://crfm.stanford.edu/helm/heim/latest/), [T2I-Compbench](https://arxiv.org/abs/2307.06350),
> [GenEval](https://arxiv.org/abs/2310.11513).
Evaluation of generative models like [Stable Diffusion](https://huggingface.co/docs/diffusers/stable_diffusion) is subjective in nature. But as practitioners and researchers, we often have to make careful choices amongst many different possibilities. So, when working with different generative models (like GANs, Diffusion, etc.), how do we choose one over the other?
Qualitative evaluation of such models can be error-prone and might incorrectly influence a decision.
@@ -0,0 +1,5 @@
# Hybrid Inference API Reference
## Remote Decode
[[autodoc]] utils.remote_utils.remote_decode
@@ -0,0 +1,54 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Hybrid Inference
**Empowering local AI builders with Hybrid Inference**
> [!TIP]
> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
## Why use Hybrid Inference?
Hybrid Inference offers a fast and simple way to offload local generation requirements.
- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
- 💰 **Cost Effective:** It's free! 🤑
- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
- 🔧 **Developer-Friendly:** Simple requests, fast responses.
---
## Available Models
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
* **VAE Encode 🔢 (coming soon):** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
---
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
## Contents
The documentation is organized into two sections:
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.
@@ -0,0 +1,345 @@
# Getting Started: VAE Decode with Hybrid Inference
VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
## Memory
These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality.
<details><summary>SD v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
</details>
<details><summary>SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
</details>
## Available VAEs
| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
## Code
> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
A helper method simplifies interacting with Hybrid Inference.
```python
from diffusers.utils.remote_utils import remote_decode
```
### Basic example
Here, we show how to use the remote VAE on random tensors.
<details><summary>Code</summary>
```python
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
scaling_factor=0.18215,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/output.png"/>
</figure>
Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
<details><summary>Code</summary>
```python
image = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
height=1024,
width=1024,
scaling_factor=0.3611,
shift_factor=0.1159,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/flux_random_latent.png"/>
</figure>
Finally, an example for HunyuanVideo.
<details><summary>Code</summary>
```python
video = remote_decode(
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
output_type="mp4",
)
with open("video.mp4", "wb") as f:
f.write(video)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video_1.mp4" type="video/mp4">
</video>
</figure>
### Generation
But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
<details><summary>Code</summary>
```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
vae=None,
).to("cuda")
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
latent = pipe(
prompt=prompt,
output_type="latent",
).images
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.18215,
)
image.save("test.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test.jpg"/>
</figure>
Heres another example with Flux.
<details><summary>Code</summary>
```python
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
vae=None,
).to("cuda")
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
latent = pipe(
prompt=prompt,
guidance_scale=0.0,
num_inference_steps=4,
output_type="latent",
).images
image = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
height=1024,
width=1024,
scaling_factor=0.3611,
shift_factor=0.1159,
)
image.save("test.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test_1.jpg"/>
</figure>
Heres an example with HunyuanVideo.
<details><summary>Code</summary>
```python
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
).to("cuda")
latent = pipe(
prompt="A cat walks on the grass, realistic",
height=320,
width=512,
num_frames=61,
num_inference_steps=30,
output_type="latent",
).frames
video = remote_decode(
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
output_type="mp4",
)
if isinstance(video, bytes):
with open("video.mp4", "wb") as f:
f.write(video)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video.mp4" type="video/mp4">
</video>
</figure>
### Queueing
One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
<details><summary>Code</summary>
```python
import queue
import threading
from IPython.display import display
from diffusers import StableDiffusionPipeline
def decode_worker(q: queue.Queue):
while True:
item = q.get()
if item is None:
break
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=item,
scaling_factor=0.18215,
)
display(image)
q.task_done()
q = queue.Queue()
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
thread.start()
def decode(latent: torch.Tensor):
q.put(latent)
prompts = [
"Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
"Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
"Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
"Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
"A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
"Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
]
pipe = StableDiffusionPipeline.from_pretrained(
"Lykon/dreamshaper-8",
torch_dtype=torch.float16,
vae=None,
).to("cuda")
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
_ = pipe(
prompt=prompts[0],
output_type="latent",
)
for prompt in prompts:
latent = pipe(
prompt=prompt,
output_type="latent",
).images
decode(latent)
q.put(None)
thread.join()
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/queue.mp4" type="video/mp4">
</video>
</figure>
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
+1
View File
@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)
- [Quanto](./quanto.md)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
+148
View File
@@ -0,0 +1,148 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Quanto
[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:
- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
- Quantized models are Device agnostic (e.g CUDA,XPU,MPS,CPU)
In order to use the Quanto backend, you will first need to install `optimum-quanto>=0.2.6` and `accelerate`
```shell
pip install optimum-quanto accelerate
```
Now you can quantize a model by passing the `QuantoConfig` object to the `from_pretrained()` method. Although the Quanto library does allow quantizing `nn.Conv2d` and `nn.LayerNorm` modules, currently, Diffusers only supports quantizing the weights in the `nn.Linear` layers of a model. The following snippet demonstrates how to apply `float8` quantization with Quanto.
```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig
model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch_dtype)
pipe.to("cuda")
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```
## Skipping Quantization on specific modules
It is possible to skip applying quantization on certain modules using the `modules_to_not_convert` argument in the `QuantoConfig`. Please ensure that the modules passed in to this argument match the keys of the modules in the `state_dict`
```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig
model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8", modules_to_not_convert=["proj_out"])
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
```
## Using `from_single_file` with the Quanto Backend
`QuantoConfig` is compatible with `~FromOriginalModelMixin.from_single_file`.
```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```
## Saving Quantized models
Diffusers supports serializing Quanto models using the `~ModelMixin.save_pretrained` method.
The serialization and loading requirements are different for models quantized directly with the Quanto library and models quantized
with Diffusers using Quanto as the backend. It is currently not possible to load models quantized directly with Quanto into Diffusers using `~ModelMixin.from_pretrained`
```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig
model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
# save quantized model to reuse
transformer.save_pretrained("<your quantized model save path>")
# you can reload your quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```
## Using `torch.compile` with Quanto
Currently the Quanto backend supports `torch.compile` for the following quantization types:
- `int8` weights
```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights_dtype="int8")
transformer = FluxTransformer2DModel.from_pretrained(
model_id,
subfolder="transformer",
quantization_config=quantization_config,
torch_dtype=torch.bfloat16,
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
pipe = FluxPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch_dtype
)
pipe.to("cuda")
images = pipe("A cat holding a sign that says hello").images[0]
images.save("flux-quanto-compile.png")
```
## Supported Quantization Types
### Weights
- float8
- int8
- int4
- int2
+1 -1
View File
@@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
```python
import torch
@@ -157,6 +157,84 @@ pipeline(
)
```
## IP Adapter Cutoff
IP Adapter is an image prompt adapter that can be used for diffusion models without any changes to the underlying model. We can use the IP Adapter Cutoff Callback to disable the IP Adapter after a certain number of steps. To set up the callback, you need to specify the number of denoising steps after which the callback comes into effect. You can do so by using either one of these two arguments:
- `cutoff_step_ratio`: Float number with the ratio of the steps.
- `cutoff_step_index`: Integer number with the exact number of the step.
We need to download the diffusion model and load the ip_adapter for it as follows:
```py
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
pipeline.set_ip_adapter_scale(0.6)
```
The setup for the callback should look something like this:
```py
from diffusers import AutoPipelineForText2Image
from diffusers.callbacks import IPAdapterScaleCutoffCallback
from diffusers.utils import load_image
import torch
pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter(
"h94/IP-Adapter",
subfolder="sdxl_models",
weight_name="ip-adapter_sdxl.bin"
)
pipeline.set_ip_adapter_scale(0.6)
callback = IPAdapterScaleCutoffCallback(
cutoff_step_ratio=None,
cutoff_step_index=5
)
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png"
)
generator = torch.Generator(device="cuda").manual_seed(2628670641)
images = pipeline(
prompt="a tiger sitting in a chair drinking orange juice",
ip_adapter_image=image,
negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
generator=generator,
num_inference_steps=50,
callback_on_step_end=callback,
).images
images[0].save("custom_callback_img.png")
```
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/without_callback.png" alt="generated image of a tiger sitting in a chair drinking orange juice" />
<figcaption class="mt-2 text-center text-sm text-gray-500">without IPAdapterScaleCutoffCallback</figcaption>
</div>
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/with_callback2.png" alt="generated image of a tiger sitting in a chair drinking orange juice with ip adapter callback" />
<figcaption class="mt-2 text-center text-sm text-gray-500">with IPAdapterScaleCutoffCallback</figcaption>
</div>
</div>
## Display image after each generation step
> [!TIP]
@@ -227,7 +227,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = nullcontext()
with autocast_ctx:
@@ -880,9 +880,7 @@ class TokenEmbeddingsHandler:
idx_to_text_encoder_name = {0: "clip_l", 1: "t5"}
for idx, text_encoder in enumerate(self.text_encoders):
train_ids = self.train_ids if idx == 0 else self.train_ids_t5
embeds = (
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
)
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
assert embeds.weight.data.shape[0] == len(self.tokenizers[idx]), "Tokenizers should be the same."
new_token_embeddings = embeds.weight.data[train_ids]
@@ -904,9 +902,7 @@ class TokenEmbeddingsHandler:
@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
embeds = (
text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.encoder.embed_tokens
)
embeds = text_encoder.text_model.embeddings.token_embedding if idx == 0 else text_encoder.shared
index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"]
embeds.weight.data[index_no_updates] = (
self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates]
@@ -1749,7 +1745,7 @@ def main(args):
if args.enable_t5_ti: # whether to do pivotal tuning/textual inversion for T5 as well
text_lora_parameters_two = []
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
if "shared" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
@@ -1883,7 +1883,11 @@ def main(args):
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.seed is not None
else None
)
pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available():
@@ -1987,7 +1991,9 @@ def main(args):
)
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
@@ -269,7 +269,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
@@ -722,7 +722,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
videos = []
for _ in range(args.num_validation_videos):
+1 -1
View File
@@ -739,7 +739,7 @@ def log_validation(
# pipe.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
videos = []
for _ in range(args.num_validation_videos):
+185
View File
@@ -10,6 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/)|
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
@@ -53,6 +54,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion Mixture Tiling Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SD 1.5](#stable-diffusion-mixture-tiling-pipeline-sd-15) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
| Stable Diffusion Mixture Canvas Pipeline SD 1.5 | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending. Works by defining a list of Text2Image region objects that detail the region of influence of each diffuser. | [Stable Diffusion Mixture Canvas Pipeline SD 1.5](#stable-diffusion-mixture-canvas-pipeline-sd-15) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/albarji/mixture-of-diffusers) | [Álvaro B Jiménez](https://github.com/albarji/) |
| Stable Diffusion Mixture Tiling Pipeline SDXL | A pipeline generates cohesive images by integrating multiple diffusion processes, each focused on a specific image region and considering boundary effects for smooth blending | [Stable Diffusion Mixture Tiling Pipeline SDXL](#stable-diffusion-mixture-tiling-pipeline-sdxl) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/elismasilva/mixture-of-diffusers-sdxl-tiling) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
| Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL | This is an advanced pipeline that leverages ControlNet Tile and Mixture-of-Diffusers techniques, integrating tile diffusion directly into the latent space denoising process. Designed to overcome the limitations of conventional pixel-space tile processing, this pipeline delivers Super Resolution (SR) upscaling for higher-quality images, reduced processing time, and greater adaptability. | [Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL](#stable-diffusion-mod-controlnet-tile-sr-pipeline-sdxl) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/elismasilva/mod-control-tile-upscaler-sdxl) | [Eliseu Silva](https://github.com/DEVAIEXP/) |
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
@@ -82,6 +84,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -91,6 +94,55 @@ pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion
## Example usages
### Spatiotemporal Skip Guidance
**Junha Hyung\*, Kinam Kim\*, Susung Hong, Min-Jung Kim, Jaegul Choo**
**KAIST AI, University of Washington**
[*Spatiotemporal Skip Guidance (STG) for Enhanced Video Diffusion Sampling*](https://arxiv.org/abs/2411.18664) (CVPR 2025) is a simple training-free sampling guidance method for enhancing transformer-based video diffusion models. STG employs an implicit weak model via self-perturbation, avoiding the need for external models or additional training. By selectively skipping spatiotemporal layers, STG produces an aligned, degraded version of the original model to boost sample quality without compromising diversity or dynamic degree.
Following is the example video of STG applied to Mochi.
https://github.com/user-attachments/assets/148adb59-da61-4c50-9dfa-425dcb5c23b3
More examples and information can be found on the [GitHub repository](https://github.com/junhahyung/STGuidance) and the [Project website](https://junhahyung.github.io/STGuidance/).
#### Usage example
```python
import torch
from pipeline_stg_mochi import MochiSTGPipeline
from diffusers.utils import export_to_video
# Load the pipeline
pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
# Enable memory savings
pipe = pipe.to("cuda")
#--------Option--------#
prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
stg_applied_layers_idx = [34]
stg_mode = "STG"
stg_scale = 1.0 # 0.0 for CFG
#----------------------#
# Generate video frames
frames = pipe(
prompt,
height=480,
width=480,
num_frames=81,
stg_applied_layers_idx=stg_applied_layers_idx,
stg_scale=stg_scale,
generator = torch.Generator().manual_seed(42),
do_rescaling=do_rescaling,
).frames[0]
export_to_video(frames, "output.mp4", fps=30)
```
### Adaptive Mask Inpainting
**Hyeonwoo Kim\*, Sookwan Han\*, Patrick Kwon, Hanbyul Joo**
@@ -2630,6 +2682,103 @@ image = pipe(
![mixture_tiling_results](https://huggingface.co/datasets/elismasilva/results/resolve/main/mixture_of_diffusers_sdxl_1.png)
### Stable Diffusion MoD ControlNet Tile SR Pipeline SDXL
This pipeline implements the [MoD (Mixture-of-Diffusers)]("https://arxiv.org/pdf/2408.06072") tiled diffusion technique and combines it with SDXL's ControlNet Tile process to generate SR images.
This works better with 4x scales, but you can try adjusts parameters to higher scales.
````python
import torch
from diffusers import DiffusionPipeline, ControlNetUnionModel, AutoencoderKL, UniPCMultistepScheduler, UNet2DConditionModel
from diffusers.utils import load_image
from PIL import Image
device = "cuda"
# Initialize the models and pipeline
controlnet = ControlNetUnionModel.from_pretrained(
"brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
).to(device=device)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16).to(device=device)
model_id = "SG161222/RealVisXL_V5.0"
pipe = DiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16,
vae=vae,
controlnet=controlnet,
custom_pipeline="mod_controlnet_tile_sr_sdxl",
use_safetensors=True,
variant="fp16",
).to(device)
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
#pipe.enable_model_cpu_offload() # << Enable this if you have limited VRAM
pipe.enable_vae_tiling() # << Enable this if you have limited VRAM
pipe.enable_vae_slicing() # << Enable this if you have limited VRAM
# Set selected scheduler
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Load image
control_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1.jpg")
original_height = control_image.height
original_width = control_image.width
print(f"Current resolution: H:{original_height} x W:{original_width}")
# Pre-upscale image for tiling
resolution = 4096
tile_gaussian_sigma = 0.3
max_tile_size = 1024 # or 1280
current_size = max(control_image.size)
scale_factor = max(2, resolution / current_size)
new_size = (int(control_image.width * scale_factor), int(control_image.height * scale_factor))
image = control_image.resize(new_size, Image.LANCZOS)
# Update target height and width
target_height = image.height
target_width = image.width
print(f"Target resolution: H:{target_height} x W:{target_width}")
# Calculate overlap size
normal_tile_overlap, border_tile_overlap = pipe.calculate_overlap(target_width, target_height)
# Set other params
tile_weighting_method = pipe.TileWeightingMethod.COSINE.value
guidance_scale = 4
num_inference_steps = 35
denoising_strenght = 0.65
controlnet_strength = 1.0
prompt = "high-quality, noise-free edges, high quality, 4k, hd, 8k"
negative_prompt = "blurry, pixelated, noisy, low resolution, artifacts, poor details"
# Image generation
generated_image = pipe(
image=image,
control_image=control_image,
control_mode=[6],
controlnet_conditioning_scale=float(controlnet_strength),
prompt=prompt,
negative_prompt=negative_prompt,
normal_tile_overlap=normal_tile_overlap,
border_tile_overlap=border_tile_overlap,
height=target_height,
width=target_width,
original_size=(original_width, original_height),
target_size=(target_width, target_height),
guidance_scale=guidance_scale,
strength=float(denoising_strenght),
tile_weighting_method=tile_weighting_method,
max_tile_size=max_tile_size,
tile_gaussian_sigma=float(tile_gaussian_sigma),
num_inference_steps=num_inference_steps,
)["images"][0]
````
![Upscaled](https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/1_input_4x.png)
### TensorRT Inpainting Stable Diffusion Pipeline
The TensorRT Pipeline can be used to accelerate the Inpainting Stable Diffusion Inference run.
@@ -5124,3 +5273,39 @@ with torch.no_grad():
In the folder examples/pixart there is also a script that can be used to train new models.
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
# CogVideoX DDIM Inversion Pipeline
This implementation performs DDIM inversion on the video based on CogVideoX and uses guided attention to reconstruct or edit the inversion latents.
## Example Usage
```python
import torch
from examples.community.cogvideox_ddim_inversion import CogVideoXPipelineForDDIMInversion
# Load pretrained pipeline
pipeline = CogVideoXPipelineForDDIMInversion.from_pretrained(
"THUDM/CogVideoX1.5-5B",
torch_dtype=torch.bfloat16,
).to("cuda")
# Run DDIM inversion, and the videos will be generated in the output_path
output = pipeline_for_inversion(
prompt="prompt that describes the edited video",
video_path="path/to/input.mp4",
guidance_scale=6.0,
num_inference_steps=50,
skip_frames_start=0,
skip_frames_end=0,
frame_sample_step=None,
max_num_frames=81,
width=720,
height=480,
seed=42,
)
pipeline.export_latents_to_video(output.inverse_latents[-1], "path/to/inverse_video.mp4", fps=8)
pipeline.export_latents_to_video(output.recon_latents[-1], "path/to/recon_video.mp4", fps=8)
```
@@ -0,0 +1,645 @@
"""
This script performs DDIM inversion for video frames using a pre-trained model and generates
a video reconstruction based on a provided prompt. It utilizes the CogVideoX pipeline to
process video frames, apply the DDIM inverse scheduler, and produce an output video.
**Please notice that this script is based on the CogVideoX 5B model, and would not generate
a good result for 2B variants.**
Usage:
python cogvideox_ddim_inversion.py
--model-path /path/to/model
--prompt "a prompt"
--video-path /path/to/video.mp4
--output-path /path/to/output
For more details about the cli arguments, please run `python cogvideox_ddim_inversion.py --help`.
Author:
LittleNyima <littlenyima[at]163[dot]com>
"""
import argparse
import math
import os
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union, cast
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.models.attention_processor import Attention, CogVideoXAttnProcessor2_0
from diffusers.models.autoencoders import AutoencoderKLCogVideoX
from diffusers.models.embeddings import apply_rotary_emb
from diffusers.models.transformers.cogvideox_transformer_3d import CogVideoXBlock, CogVideoXTransformer3DModel
from diffusers.pipelines.cogvideo.pipeline_cogvideox import CogVideoXPipeline, retrieve_timesteps
from diffusers.schedulers import CogVideoXDDIMScheduler, DDIMInverseScheduler
from diffusers.utils import export_to_video
# Must import after torch because this can sometimes lead to a nasty segmentation fault, or stack smashing error.
# Very few bug reports but it happens. Look in decord Github issues for more relevant information.
import decord # isort: skip
class DDIMInversionArguments(TypedDict):
model_path: str
prompt: str
video_path: str
output_path: str
guidance_scale: float
num_inference_steps: int
skip_frames_start: int
skip_frames_end: int
frame_sample_step: Optional[int]
max_num_frames: int
width: int
height: int
fps: int
dtype: torch.dtype
seed: int
device: torch.device
def get_args() -> DDIMInversionArguments:
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, required=True, help="Path of the pretrained model")
parser.add_argument("--prompt", type=str, required=True, help="Prompt for the direct sample procedure")
parser.add_argument("--video_path", type=str, required=True, help="Path of the video for inversion")
parser.add_argument("--output_path", type=str, default="output", help="Path of the output videos")
parser.add_argument("--guidance_scale", type=float, default=6.0, help="Classifier-free guidance scale")
parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of inference steps")
parser.add_argument("--skip_frames_start", type=int, default=0, help="Number of skipped frames from the start")
parser.add_argument("--skip_frames_end", type=int, default=0, help="Number of skipped frames from the end")
parser.add_argument("--frame_sample_step", type=int, default=None, help="Temporal stride of the sampled frames")
parser.add_argument("--max_num_frames", type=int, default=81, help="Max number of sampled frames")
parser.add_argument("--width", type=int, default=720, help="Resized width of the video frames")
parser.add_argument("--height", type=int, default=480, help="Resized height of the video frames")
parser.add_argument("--fps", type=int, default=8, help="Frame rate of the output videos")
parser.add_argument("--dtype", type=str, default="bf16", choices=["bf16", "fp16"], help="Dtype of the model")
parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"], help="Device for inference")
args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
args.device = torch.device(args.device)
return DDIMInversionArguments(**vars(args))
class CogVideoXAttnProcessor2_0ForDDIMInversion(CogVideoXAttnProcessor2_0):
def __init__(self):
super().__init__()
def calculate_attention(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn: Attention,
batch_size: int,
image_seq_length: int,
text_seq_length: int,
attention_mask: Optional[torch.Tensor],
image_rotary_emb: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Core attention computation with inversion-guided RoPE integration.
Args:
query (`torch.Tensor`): `[batch_size, seq_len, dim]` query tensor
key (`torch.Tensor`): `[batch_size, seq_len, dim]` key tensor
value (`torch.Tensor`): `[batch_size, seq_len, dim]` value tensor
attn (`Attention`): Parent attention module with projection layers
batch_size (`int`): Effective batch size (after chunk splitting)
image_seq_length (`int`): Length of image feature sequence
text_seq_length (`int`): Length of text feature sequence
attention_mask (`Optional[torch.Tensor]`): Attention mask tensor
image_rotary_emb (`Optional[torch.Tensor]`): Rotary embeddings for image positions
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
(1) hidden_states: [batch_size, image_seq_length, dim] processed image features
(2) encoder_hidden_states: [batch_size, text_seq_length, dim] processed text features
"""
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE if needed
if image_rotary_emb is not None:
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
if not attn.is_cross_attention:
if key.size(2) == query.size(2): # Attention for reference hidden states
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
else: # RoPE should be applied to each group of image tokens
key[:, :, text_seq_length : text_seq_length + image_seq_length] = apply_rotary_emb(
key[:, :, text_seq_length : text_seq_length + image_seq_length], image_rotary_emb
)
key[:, :, text_seq_length * 2 + image_seq_length :] = apply_rotary_emb(
key[:, :, text_seq_length * 2 + image_seq_length :], image_rotary_emb
)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Process the dual-path attention for the inversion-guided denoising procedure.
Args:
attn (`Attention`): Parent attention module
hidden_states (`torch.Tensor`): `[batch_size, image_seq_len, dim]` Image tokens
encoder_hidden_states (`torch.Tensor`): `[batch_size, text_seq_len, dim]` Text tokens
attention_mask (`Optional[torch.Tensor]`): Optional attention mask
image_rotary_emb (`Optional[torch.Tensor]`): Rotary embeddings for image tokens
Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
(1) Final hidden states: `[batch_size, image_seq_length, dim]` Resulting image tokens
(2) Final encoder states: `[batch_size, text_seq_length, dim]` Resulting text tokens
"""
image_seq_length = hidden_states.size(1)
text_seq_length = encoder_hidden_states.size(1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query, query_reference = query.chunk(2)
key, key_reference = key.chunk(2)
value, value_reference = value.chunk(2)
batch_size = batch_size // 2
hidden_states, encoder_hidden_states = self.calculate_attention(
query=query,
key=torch.cat((key, key_reference), dim=1),
value=torch.cat((value, value_reference), dim=1),
attn=attn,
batch_size=batch_size,
image_seq_length=image_seq_length,
text_seq_length=text_seq_length,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states_reference, encoder_hidden_states_reference = self.calculate_attention(
query=query_reference,
key=key_reference,
value=value_reference,
attn=attn,
batch_size=batch_size,
image_seq_length=image_seq_length,
text_seq_length=text_seq_length,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
return (
torch.cat((hidden_states, hidden_states_reference)),
torch.cat((encoder_hidden_states, encoder_hidden_states_reference)),
)
class OverrideAttnProcessors:
r"""
Context manager for temporarily overriding attention processors in CogVideo transformer blocks.
Designed for DDIM inversion process, replaces original attention processors with
`CogVideoXAttnProcessor2_0ForDDIMInversion` and restores them upon exit. Uses Python context manager
pattern to safely manage processor replacement.
Typical usage:
```python
with OverrideAttnProcessors(transformer):
# Perform DDIM inversion operations
```
Args:
transformer (`CogVideoXTransformer3DModel`):
The transformer model containing attention blocks to be modified. Should have
`transformer_blocks` attribute containing `CogVideoXBlock` instances.
"""
def __init__(self, transformer: CogVideoXTransformer3DModel):
self.transformer = transformer
self.original_processors = {}
def __enter__(self):
for block in self.transformer.transformer_blocks:
block = cast(CogVideoXBlock, block)
self.original_processors[id(block)] = block.attn1.get_processor()
block.attn1.set_processor(CogVideoXAttnProcessor2_0ForDDIMInversion())
def __exit__(self, _0, _1, _2):
for block in self.transformer.transformer_blocks:
block = cast(CogVideoXBlock, block)
block.attn1.set_processor(self.original_processors[id(block)])
def get_video_frames(
video_path: str,
width: int,
height: int,
skip_frames_start: int,
skip_frames_end: int,
max_num_frames: int,
frame_sample_step: Optional[int],
) -> torch.FloatTensor:
"""
Extract and preprocess video frames from a video file for VAE processing.
Args:
video_path (`str`): Path to input video file
width (`int`): Target frame width for decoding
height (`int`): Target frame height for decoding
skip_frames_start (`int`): Number of frames to skip at video start
skip_frames_end (`int`): Number of frames to skip at video end
max_num_frames (`int`): Maximum allowed number of output frames
frame_sample_step (`Optional[int]`):
Frame sampling step size. If None, automatically calculated as:
(total_frames - skipped_frames) // max_num_frames
Returns:
`torch.FloatTensor`: Preprocessed frames in `[F, C, H, W]` format where:
- `F`: Number of frames (adjusted to 4k + 1 for VAE compatibility)
- `C`: Channels (3 for RGB)
- `H`: Frame height
- `W`: Frame width
"""
with decord.bridge.use_torch():
video_reader = decord.VideoReader(uri=video_path, width=width, height=height)
video_num_frames = len(video_reader)
start_frame = min(skip_frames_start, video_num_frames)
end_frame = max(0, video_num_frames - skip_frames_end)
if end_frame <= start_frame:
indices = [start_frame]
elif end_frame - start_frame <= max_num_frames:
indices = list(range(start_frame, end_frame))
else:
step = frame_sample_step or (end_frame - start_frame) // max_num_frames
indices = list(range(start_frame, end_frame, step))
frames = video_reader.get_batch(indices=indices)
frames = frames[:max_num_frames].float() # ensure that we don't go over the limit
# Choose first (4k + 1) frames as this is how many is required by the VAE
selected_num_frames = frames.size(0)
remainder = (3 + selected_num_frames) % 4
if remainder != 0:
frames = frames[:-remainder]
assert frames.size(0) % 4 == 1
# Normalize the frames
transform = T.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)
frames = torch.stack(tuple(map(transform, frames)), dim=0)
return frames.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
class CogVideoXDDIMInversionOutput:
inverse_latents: torch.FloatTensor
recon_latents: torch.FloatTensor
def __init__(self, inverse_latents: torch.FloatTensor, recon_latents: torch.FloatTensor):
self.inverse_latents = inverse_latents
self.recon_latents = recon_latents
class CogVideoXPipelineForDDIMInversion(CogVideoXPipeline):
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: CogVideoXDDIMScheduler,
):
super().__init__(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
self.inverse_scheduler = DDIMInverseScheduler(**scheduler.config)
def encode_video_frames(self, video_frames: torch.FloatTensor) -> torch.FloatTensor:
"""
Encode video frames into latent space using Variational Autoencoder.
Args:
video_frames (`torch.FloatTensor`):
Input frames tensor in `[F, C, H, W]` format from `get_video_frames()`
Returns:
`torch.FloatTensor`: Encoded latents in `[1, F, D, H_latent, W_latent]` format where:
- `F`: Number of frames (same as input)
- `D`: Latent channel dimension
- `H_latent`: Latent space height (H // 2^vae.downscale_factor)
- `W_latent`: Latent space width (W // 2^vae.downscale_factor)
"""
vae: AutoencoderKLCogVideoX = self.vae
video_frames = video_frames.to(device=vae.device, dtype=vae.dtype)
video_frames = video_frames.unsqueeze(0).permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(x=video_frames).latent_dist.sample().transpose(1, 2)
return latent_dist * vae.config.scaling_factor
@torch.no_grad()
def export_latents_to_video(self, latents: torch.FloatTensor, video_path: str, fps: int):
r"""
Decode latent vectors into video and export as video file.
Args:
latents (`torch.FloatTensor`): Encoded latents in `[B, F, D, H_latent, W_latent]` format from
`encode_video_frames()`
video_path (`str`): Output path for video file
fps (`int`): Target frames per second for output video
"""
video = self.decode_latents(latents)
frames = self.video_processor.postprocess_video(video=video, output_type="pil")
os.makedirs(os.path.dirname(video_path), exist_ok=True)
export_to_video(video_frames=frames[0], output_video_path=video_path, fps=fps)
# Modified from CogVideoXPipeline.__call__
@torch.no_grad()
def sample(
self,
latents: torch.FloatTensor,
scheduler: Union[DDIMInverseScheduler, CogVideoXDDIMScheduler],
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_inference_steps: int = 50,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
reference_latents: torch.FloatTensor = None,
) -> torch.FloatTensor:
r"""
Execute the core sampling loop for video generation/inversion using CogVideoX.
Implements the full denoising trajectory recording for both DDIM inversion and
generation processes. Supports dynamic classifier-free guidance and reference
latent conditioning.
Args:
latents (`torch.FloatTensor`):
Initial noise tensor of shape `[B, F, C, H, W]`.
scheduler (`Union[DDIMInverseScheduler, CogVideoXDDIMScheduler]`):
Scheduling strategy for diffusion process. Use:
(1) `DDIMInverseScheduler` for inversion
(2) `CogVideoXDDIMScheduler` for generation
prompt (`Optional[Union[str, List[str]]]`):
Text prompt(s) for conditional generation. Defaults to unconditional.
negative_prompt (`Optional[Union[str, List[str]]]`):
Negative prompt(s) for guidance. Requires `guidance_scale > 1`.
num_inference_steps (`int`):
Number of denoising steps. Affects quality/compute trade-off.
guidance_scale (`float`):
Classifier-free guidance weight. 1.0 = no guidance.
use_dynamic_cfg (`bool`):
Enable time-varying guidance scale (cosine schedule)
eta (`float`):
DDIM variance parameter (0 = deterministic process)
generator (`Optional[Union[torch.Generator, List[torch.Generator]]]`):
Random number generator(s) for reproducibility
attention_kwargs (`Optional[Dict[str, Any]]`):
Custom parameters for attention modules
reference_latents (`torch.FloatTensor`):
Reference latent trajectory for conditional sampling. Shape should match
`[T, B, F, C, H, W]` where `T` is number of timesteps
Returns:
`torch.FloatTensor`:
Full denoising trajectory tensor of shape `[T, B, F, C, H, W]`.
"""
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
if reference_latents is not None:
prompt_embeds = torch.cat([prompt_embeds] * 2, dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps, device)
self._num_timesteps = len(timesteps)
# 5. Prepare latents.
latents = latents.to(device=device) * scheduler.init_noise_sigma
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
if isinstance(scheduler, DDIMInverseScheduler): # Inverse scheduler does not accept extra kwargs
extra_step_kwargs = {}
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(
height=latents.size(3) * self.vae_scale_factor_spatial,
width=latents.size(4) * self.vae_scale_factor_spatial,
num_frames=latents.size(1),
device=device,
)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
trajectory = torch.zeros_like(latents).unsqueeze(0).repeat(len(timesteps), 1, 1, 1, 1, 1)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
if reference_latents is not None:
reference = reference_latents[i]
reference = torch.cat([reference] * 2) if do_classifier_free_guidance else reference
latent_model_input = torch.cat([latent_model_input, reference], dim=0)
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if reference_latents is not None: # Recover the original batch size
noise_pred, _ = noise_pred.chunk(2)
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the noisy sample x_t-1 -> x_t
latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents = latents.to(prompt_embeds.dtype)
trajectory[i] = latents
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
# Offload all models
self.maybe_free_model_hooks()
return trajectory
@torch.no_grad()
def __call__(
self,
prompt: str,
video_path: str,
guidance_scale: float,
num_inference_steps: int,
skip_frames_start: int,
skip_frames_end: int,
frame_sample_step: Optional[int],
max_num_frames: int,
width: int,
height: int,
seed: int,
):
"""
Performs DDIM inversion on a video to reconstruct it with a new prompt.
Args:
prompt (`str`): The text prompt to guide the reconstruction.
video_path (`str`): Path to the input video file.
guidance_scale (`float`): Scale for classifier-free guidance.
num_inference_steps (`int`): Number of denoising steps.
skip_frames_start (`int`): Number of frames to skip from the beginning of the video.
skip_frames_end (`int`): Number of frames to skip from the end of the video.
frame_sample_step (`Optional[int]`): Step size for sampling frames. If None, all frames are used.
max_num_frames (`int`): Maximum number of frames to process.
width (`int`): Width of the output video frames.
height (`int`): Height of the output video frames.
seed (`int`): Random seed for reproducibility.
Returns:
`CogVideoXDDIMInversionOutput`: Contains the inverse latents and reconstructed latents.
"""
if not self.transformer.config.use_rotary_positional_embeddings:
raise NotImplementedError("This script supports CogVideoX 5B model only.")
video_frames = get_video_frames(
video_path=video_path,
width=width,
height=height,
skip_frames_start=skip_frames_start,
skip_frames_end=skip_frames_end,
max_num_frames=max_num_frames,
frame_sample_step=frame_sample_step,
).to(device=self.device)
video_latents = self.encode_video_frames(video_frames=video_frames)
inverse_latents = self.sample(
latents=video_latents,
scheduler=self.inverse_scheduler,
prompt="",
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator(device=self.device).manual_seed(seed),
)
with OverrideAttnProcessors(transformer=self.transformer):
recon_latents = self.sample(
latents=torch.randn_like(video_latents),
scheduler=self.scheduler,
prompt=prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
generator=torch.Generator(device=self.device).manual_seed(seed),
reference_latents=reversed(inverse_latents),
)
return CogVideoXDDIMInversionOutput(
inverse_latents=inverse_latents,
recon_latents=recon_latents,
)
if __name__ == "__main__":
arguments = get_args()
pipeline = CogVideoXPipelineForDDIMInversion.from_pretrained(
arguments.pop("model_path"),
torch_dtype=arguments.pop("dtype"),
).to(device=arguments.pop("device"))
output_path = arguments.pop("output_path")
fps = arguments.pop("fps")
inverse_video_path = os.path.join(output_path, f"{arguments.get('video_path')}_inversion.mp4")
recon_video_path = os.path.join(output_path, f"{arguments.get('video_path')}_reconstruction.mp4")
# Run DDIM inversion
output = pipeline(**arguments)
pipeline.export_latents_to_video(output.inverse_latents[-1], inverse_video_path, fps)
pipeline.export_latents_to_video(output.recon_latents[-1], recon_video_path, fps)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,876 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import CogVideoXLoraLoaderMixin
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from examples.community.pipeline_stg_cogvideox import CogVideoXSTGPipeline
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
>>> pipe = CogVideoXSTGPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16).to("cuda")
>>> prompt = (
... "A father and son building a treehouse together, their hands covered in sawdust and smiles on their faces, realistic style."
... )
>>> pipe.transformer.to(memory_format=torch.channels_last)
>>> # Configure STG mode options
>>> stg_applied_layers_idx = [11] # Layer indices from 0 to 41
>>> stg_scale = 1.0 # Set to 0.0 for CFG
>>> do_rescaling = False
>>> # Generate video frames with STG parameters
>>> frames = pipe(
... prompt=prompt,
... stg_applied_layers_idx=stg_applied_layers_idx,
... stg_scale=stg_scale,
... do_rescaling=do_rescaling,
>>> ).frames[0]
>>> export_to_video(frames, "output.mp4", fps=8)
```
"""
def forward_with_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
hidden_states_ptb = hidden_states[2:]
encoder_hidden_states_ptb = encoder_hidden_states[2:]
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
hidden_states[2:] = hidden_states_ptb
encoder_hidden_states[2:] = encoder_hidden_states_ptb
return hidden_states, encoder_hidden_states
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CogVideoXSTGPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using CogVideoX.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogVideoX uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
)
self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
shape = (
batch_size,
(num_frames - 1) // self.vae_scale_factor_temporal + 1,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae_scaling_factor_image * latents
frames = self.vae.decode(latents).sample
return frames
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
p = self.transformer.config.patch_size
p_t = self.transformer.config.patch_size_t
base_size_width = self.transformer.config.sample_width // p
base_size_height = self.transformer.config.sample_height // p
if p_t is None:
# CogVideoX 1.0
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
device=device,
)
else:
# CogVideoX 1.5
base_num_frames = (num_frames + p_t - 1) // p_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
device=device,
)
return freqs_cos, freqs_sin
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_spatio_temporal_guidance(self):
return self._stg_scale > 0.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
stg_applied_layers_idx: Optional[List[int]] = [11],
stg_scale: Optional[float] = 0.0,
do_rescaling: Optional[bool] = False,
) -> Union[CogVideoXPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
The width in pixels of the generated image. This is set to 720 by default for the best results.
num_frames (`int`, defaults to `48`):
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
needs to be satisfied is that of divisibility mentioned above.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `226`):
Maximum sequence length in encoded prompt. Must be consistent with
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
Examples:
Returns:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] or `tuple`:
[`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
num_frames = num_frames or self.transformer.config.sample_frames
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._stg_scale = stg_scale
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
if self.do_spatio_temporal_guidance:
for i in stg_applied_layers_idx:
self.transformer.transformer_blocks[i].forward = types.MethodType(
forward_with_stg, self.transformer.transformer_blocks[i]
)
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
self._num_timesteps = len(timesteps)
# 5. Prepare latents
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
patch_size_t = self.transformer.config.patch_size_t
additional_frames = 0
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.vae_scale_factor_temporal
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
latent_channels,
num_frames,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
latent_model_input = torch.cat([latents] * 2)
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
latent_model_input = torch.cat([latents] * 3)
else:
latent_model_input = latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
elif do_classifier_free_guidance and self.do_spatio_temporal_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
)
if do_rescaling:
rescaling_scale = 0.7
factor = noise_pred_text.std() / noise_pred.std()
factor = rescaling_scale * factor + (1 - rescaling_scale)
noise_pred = noise_pred * factor
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
@@ -0,0 +1,794 @@
# Copyright 2024 The HunyuanVideo Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import HunyuanVideoLoraLoaderMixin
from diffusers.models import AutoencoderKLHunyuanVideo, HunyuanVideoTransformer3DModel
from diffusers.pipelines.hunyuan_video.pipeline_output import HunyuanVideoPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from diffusers import HunyuanVideoTransformer3DModel
>>> from examples.community.pipeline_stg_hunyuan_video import HunyuanVideoSTGPipeline
>>> model_id = "hunyuanvideo-community/HunyuanVideo"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> pipe = HunyuanVideoSTGPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
>>> pipe.vae.enable_tiling()
>>> pipe.to("cuda")
>>> # Configure STG mode options
>>> stg_applied_layers_idx = [2] # Layer indices from 0 to 41
>>> stg_scale = 1.0 # Set 0.0 for CFG
>>> output = pipe(
... prompt="A wolf howling at the moon, with the moon subtly resembling a giant clock face, realistic style.",
... height=320,
... width=512,
... num_frames=61,
... num_inference_steps=30,
... stg_applied_layers_idx=stg_applied_layers_idx,
... stg_scale=stg_scale,
>>> ).frames[0]
>>> export_to_video(output, "output.mp4", fps=15)
```
"""
DEFAULT_PROMPT_TEMPLATE = {
"template": (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
),
"crop_start": 95,
}
def forward_with_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
return hidden_states, encoder_hidden_states
def forward_without_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=freqs_cis,
)
# 3. Modulation and residual connection
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return hidden_states, encoder_hidden_states
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class HunyuanVideoSTGPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using HunyuanVideo.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`LlamaModel`]):
[Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
tokenizer (`LlamaTokenizer`):
Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers).
transformer ([`HunyuanVideoTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLHunyuanVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder_2 ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer_2 (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
text_encoder: LlamaModel,
tokenizer: LlamaTokenizerFast,
transformer: HunyuanVideoTransformer3DModel,
vae: AutoencoderKLHunyuanVideo,
scheduler: FlowMatchEulerDiscreteScheduler,
text_encoder_2: CLIPTextModel,
tokenizer_2: CLIPTokenizer,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
)
self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_llama_prompt_embeds(
self,
prompt: Union[str, List[str]],
prompt_template: Dict[str, Any],
num_videos_per_prompt: int = 1,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
num_hidden_layers_to_skip: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
prompt = [prompt_template["template"].format(p) for p in prompt]
crop_start = prompt_template.get("crop_start", None)
if crop_start is None:
prompt_template_input = self.tokenizer(
prompt_template["template"],
padding="max_length",
return_tensors="pt",
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=False,
)
crop_start = prompt_template_input["input_ids"].shape[-1]
# Remove <|eot_id|> token and placeholder {}
crop_start -= 2
max_sequence_length += crop_start
text_inputs = self.tokenizer(
prompt,
max_length=max_sequence_length,
padding="max_length",
truncation=True,
return_tensors="pt",
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
)
text_input_ids = text_inputs.input_ids.to(device=device)
prompt_attention_mask = text_inputs.attention_mask.to(device=device)
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_attention_mask,
output_hidden_states=True,
).hidden_states[-(num_hidden_layers_to_skip + 1)]
prompt_embeds = prompt_embeds.to(dtype=dtype)
if crop_start is not None and crop_start > 0:
prompt_embeds = prompt_embeds[:, crop_start:]
prompt_attention_mask = prompt_attention_mask[:, crop_start:]
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt)
prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
return prompt_embeds, prompt_attention_mask
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_videos_per_prompt: int = 1,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 77,
) -> torch.Tensor:
device = device or self._execution_device
dtype = dtype or self.text_encoder_2.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]] = None,
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
prompt,
prompt_template,
num_videos_per_prompt,
device=device,
dtype=dtype,
max_sequence_length=max_sequence_length,
)
if pooled_prompt_embeds is None:
if prompt_2 is None and pooled_prompt_embeds is None:
prompt_2 = prompt
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt,
num_videos_per_prompt,
device=device,
dtype=dtype,
max_sequence_length=77,
)
return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
prompt_template=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if prompt_template is not None:
if not isinstance(prompt_template, dict):
raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}")
if "template" not in prompt_template:
raise ValueError(
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
)
def prepare_latents(
self,
batch_size: int,
num_channels_latents: 32,
height: int = 720,
width: int = 1280,
num_frames: int = 129,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
num_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_spatio_temporal_guidance(self):
return self._stg_scale > 0.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Union[str, List[str]] = None,
height: int = 720,
width: int = 1280,
num_frames: int = 129,
num_inference_steps: int = 50,
sigmas: List[float] = None,
guidance_scale: float = 6.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
max_sequence_length: int = 256,
stg_applied_layers_idx: Optional[List[int]] = [2],
stg_scale: Optional[float] = 0.0,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead.
height (`int`, defaults to `720`):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `129`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, defaults to `6.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality. Note that the only available HunyuanVideo model is
CFG-distilled, which means that traditional guidance between unconditional and conditional latent is
not applied.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~HunyuanVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned
where the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds,
callback_on_step_end_tensor_inputs,
prompt_template,
)
self._stg_scale = stg_scale
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_template=prompt_template,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
device=device,
max_sequence_length=max_sequence_length,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
if pooled_prompt_embeds is not None:
pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_latent_frames,
torch.float32,
device,
generator,
latents,
)
# 6. Prepare guidance condition
guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
if self.do_spatio_temporal_guidance:
for i in stg_applied_layers_idx:
self.transformer.transformer_blocks[i].forward = types.MethodType(
forward_without_stg, self.transformer.transformer_blocks[i]
)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_spatio_temporal_guidance:
for i in stg_applied_layers_idx:
self.transformer.transformer_blocks[i].forward = types.MethodType(
forward_with_stg, self.transformer.transformer_blocks[i]
)
noise_pred_perturb = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred + self._stg_scale * (noise_pred - noise_pred_perturb)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return HunyuanVideoPipelineOutput(frames=video)
+886
View File
@@ -0,0 +1,886 @@
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from examples.community.pipeline_stg_ltx import LTXSTGPipeline
>>> pipe = LTXSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A woman with light skin, wearing a blue jacket and a black hat with a veil, looks down and to her right, then back up as she speaks; she has brown hair styled in an updo, light brown eyebrows, and is wearing a white collared shirt under her jacket; the camera remains stationary on her face as she speaks; the background is out of focus, but shows trees and people in period clothing; the scene is captured in real-life footage."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> # Configure STG mode options
>>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
>>> stg_scale = 1.0 # Set 0.0 for CFG
>>> do_rescaling = False
>>> video = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=704,
... height=480,
... num_frames=161,
... num_inference_steps=50,
... stg_applied_layers_idx=stg_applied_layers_idx,
... stg_scale=stg_scale,
... do_rescaling=do_rescaling,
>>> ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```
"""
def forward_with_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states_ptb = hidden_states[2:]
encoder_hidden_states_ptb = encoder_hidden_states[2:]
batch_size = hidden_states.size(0)
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa
attn_hidden_states = self.attn2(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp
hidden_states[2:] = hidden_states_ptb
encoder_hidden_states[2:] = encoder_hidden_states_ptb
return hidden_states
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class LTXSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for text-to-video generation.
Reference: https://github.com/Lightricks/LTX-Video
Args:
transformer ([`LTXVideoTransformer3DModel`]):
Conditional Transformer architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLLTXVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLLTXVideo,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: LTXVideoTransformer3DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.transformer_spatial_patch_size = (
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
)
self.transformer_temporal_patch_size = (
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
def check_inputs(
self,
prompt,
height,
width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
@staticmethod
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
batch_size, num_channels, num_frames, height, width = latents.shape
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size
latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
@staticmethod
def _unpack_latents(
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
# what happens in the `_pack_latents` method.
batch_size = latents.size(0)
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
@staticmethod
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
def prepare_latents(
self,
batch_size: int = 1,
num_channels_latents: int = 128,
height: int = 512,
width: int = 704,
num_frames: int = 161,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
height = height // self.vae_spatial_compression_ratio
width = width // self.vae_spatial_compression_ratio
num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
shape = (batch_size, num_channels_latents, num_frames, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def do_spatio_temporal_guidance(self):
return self._stg_scale > 0.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 704,
num_frames: int = 161,
frame_rate: int = 25,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 128,
stg_applied_layers_idx: Optional[List[int]] = [19],
stg_scale: Optional[float] = 1.0,
do_rescaling: Optional[bool] = False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `512`):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, defaults to `704`):
The width in pixels of the generated image. This is set to 848 by default for the best results.
num_frames (`int`, defaults to `161`):
The number of video frames to generate
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `3 `):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to `128 `):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
self._stg_scale = stg_scale
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
if self.do_spatio_temporal_guidance:
for i in stg_applied_layers_idx:
self.transformer.transformer_blocks[i].forward = types.MethodType(
forward_with_stg, self.transformer.transformer_blocks[i]
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Prepare text embeddings
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat(
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
latent_model_input = torch.cat([latents] * 2)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
latent_model_input = torch.cat([latents] * 3)
else:
latent_model_input = latents
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
)
if do_rescaling:
rescaling_scale = 0.7
factor = noise_pred_text.std() / noise_pred.std()
factor = rescaling_scale * factor + (1 - rescaling_scale)
noise_pred = noise_pred * factor
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
video = latents
else:
latents = self._unpack_latents(
latents,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)
@@ -0,0 +1,985 @@
# Copyright 2024 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import PipelineImageInput
from diffusers.loaders import FromSingleFileMixin, LTXVideoLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKLLTXVideo
from diffusers.models.transformers import LTXVideoTransformer3DModel
from diffusers.pipelines.ltx.pipeline_output import LTXPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers.utils import export_to_video, load_image
>>> from examples.community.pipeline_stg_ltx_image2video import LTXImageToVideoSTGPipeline
>>> pipe = LTXImageToVideoSTGPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> image = load_image(
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/11.png"
>>> )
>>> prompt = "A medieval fantasy scene featuring a rugged man with shoulder-length brown hair and a beard. He wears a dark leather tunic over a maroon shirt with intricate metal details. His facial expression is serious and intense, and he is making a gesture with his right hand, forming a small circle with his thumb and index finger. The warm golden lighting casts dramatic shadows on his face. The background includes an ornate stone arch and blurred medieval-style decor, creating an epic atmosphere."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> # Configure STG mode options
>>> stg_applied_layers_idx = [19] # Layer indices from 0 to 41
>>> stg_scale = 1.0 # Set 0.0 for CFG
>>> do_rescaling = False
>>> video = pipe(
... image=image,
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=704,
... height=480,
... num_frames=161,
... num_inference_steps=50,
... stg_applied_layers_idx=stg_applied_layers_idx,
... stg_scale=stg_scale,
... do_rescaling=do_rescaling,
>>> ).frames[0]
>>> export_to_video(video, "output.mp4", fps=24)
```
"""
def forward_with_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states_ptb = hidden_states[2:]
encoder_hidden_states_ptb = encoder_hidden_states[2:]
batch_size = hidden_states.size(0)
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa
attn_hidden_states = self.attn2(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp
hidden_states[2:] = hidden_states_ptb
encoder_hidden_states[2:] = encoder_hidden_states_ptb
return hidden_states
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LTXImageToVideoSTGPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for image-to-video generation.
Reference: https://github.com/Lightricks/LTX-Video
Args:
transformer ([`LTXVideoTransformer3DModel`]):
Conditional Transformer architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLLTXVideo`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLLTXVideo,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: LTXVideoTransformer3DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.transformer_spatial_patch_size = (
self.transformer.config.patch_size if getattr(self, "transformer", None) is not None else 1
)
self.transformer_temporal_patch_size = (
self.transformer.config.patch_size_t if getattr(self, "transformer") is not None else 1
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 128
)
self.default_height = 512
self.default_width = 704
self.default_frames = 121
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline.encode_prompt with 256->128
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 128,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline.check_inputs
def check_inputs(
self,
prompt,
height,
width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
# The patch dimensions are then permuted and collapsed into the channel dimension of shape:
# [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
# dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
batch_size, num_channels, num_frames, height, width = latents.shape
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size
latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._unpack_latents
def _unpack_latents(
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
# what happens in the `_pack_latents` method.
batch_size = latents.size(0)
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
def prepare_latents(
self,
image: Optional[torch.Tensor] = None,
batch_size: int = 1,
num_channels_latents: int = 128,
height: int = 512,
width: int = 704,
num_frames: int = 161,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
height = height // self.vae_spatial_compression_ratio
width = width // self.vae_spatial_compression_ratio
num_frames = (
(num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
)
shape = (batch_size, num_channels_latents, num_frames, height, width)
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(shape)
conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
return latents.to(device=device, dtype=dtype), conditioning_mask
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(image[i].unsqueeze(0).unsqueeze(2)), generator[i])
for i in range(batch_size)
]
else:
init_latents = [
retrieve_latents(self.vae.encode(img.unsqueeze(0).unsqueeze(2)), generator) for img in image
]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
conditioning_mask = torch.zeros(mask_shape, device=device, dtype=dtype)
conditioning_mask[:, :, 0] = 1.0
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
).squeeze(-1)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
return latents, conditioning_mask
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def do_spatio_temporal_guidance(self):
return self._stg_scale > 0.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput = None,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
width: int = 704,
num_frames: int = 161,
frame_rate: int = 25,
num_inference_steps: int = 50,
timesteps: List[int] = None,
guidance_scale: float = 3,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 128,
stg_applied_layers_idx: Optional[List[int]] = [19],
stg_scale: Optional[float] = 1.0,
do_rescaling: Optional[bool] = False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
image (`PipelineImageInput`):
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `512`):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, defaults to `704`):
The width in pixels of the generated image. This is set to 848 by default for the best results.
num_frames (`int`, defaults to `161`):
The number of video frames to generate
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `3 `):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to `128 `):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
self._stg_scale = stg_scale
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
if self.do_spatio_temporal_guidance:
for i in stg_applied_layers_idx:
self.transformer.transformer_blocks[i].forward = types.MethodType(
forward_with_stg, self.transformer.transformer_blocks[i]
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Prepare text embeddings
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat(
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
)
# 4. Prepare latent variables
if latents is None:
image = self.video_processor.preprocess(image, height=height, width=width)
image = image.to(device=device, dtype=prompt_embeds.dtype)
num_channels_latents = self.transformer.config.in_channels
latents, conditioning_mask = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask])
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
conditioning_mask = torch.cat([conditioning_mask, conditioning_mask, conditioning_mask])
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.16),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
latent_model_input = torch.cat([latents] * 2)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
latent_model_input = torch.cat([latents] * 3)
else:
latent_model_input = latents
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
timestep, _ = timestep.chunk(2)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
)
timestep, _, _ = timestep.chunk(3)
if do_rescaling:
rescaling_scale = 0.7
factor = noise_pred_text.std() / noise_pred.std()
factor = rescaling_scale * factor + (1 - rescaling_scale)
noise_pred = noise_pred * factor
# compute the previous noisy sample x_t -> x_t-1
noise_pred = self._unpack_latents(
noise_pred,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._unpack_latents(
latents,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
noise_pred = noise_pred[:, :, 1:]
noise_latents = latents[:, :, 1:]
pred_latents = self.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
latents = self._pack_latents(
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
video = latents
else:
latents = self._unpack_latents(
latents,
latent_num_frames,
latent_height,
latent_width,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(prompt_embeds.dtype)
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)
+843
View File
@@ -0,0 +1,843 @@
# Copyright 2024 Genmo and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import types
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import Mochi1LoraLoaderMixin
from diffusers.models import AutoencoderKLMochi, MochiTransformer3DModel
from diffusers.pipelines.mochi.pipeline_output import MochiPipelineOutput
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from examples.community.pipeline_stg_mochi import MochiSTGPipeline
>>> pipe = MochiSTGPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
>>> pipe.enable_model_cpu_offload()
>>> pipe.enable_vae_tiling()
>>> prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
>>> # Configure STG mode options
>>> stg_applied_layers_idx = [34] # Layer indices from 0 to 41
>>> stg_scale = 1.0 # Set 0.0 for CFG
>>> do_rescaling = False
>>> frames = pipe(
... prompt=prompt,
... num_inference_steps=28,
... guidance_scale=3.5,
... stg_applied_layers_idx=stg_applied_layers_idx,
... stg_scale=stg_scale,
... do_rescaling=do_rescaling).frames[0]
>>> export_to_video(frames, "mochi.mp4")
```
"""
def forward_with_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_attention_mask: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states_ptb = hidden_states[2:]
encoder_hidden_states_ptb = encoder_hidden_states[2:]
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
if not self.context_pre_only:
norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
encoder_hidden_states, temb
)
else:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
attn_hidden_states, context_attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
if not self.context_pre_only:
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
)
norm_encoder_hidden_states = self.norm3_context(
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
)
hidden_states[2:] = hidden_states_ptb
encoder_hidden_states[2:] = encoder_hidden_states_ptb
return hidden_states, encoder_hidden_states
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
if linear_steps is None:
linear_steps = num_steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
quadratic_steps = num_steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
const = quadratic_coef * (linear_steps**2)
quadratic_sigma_schedule = [
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom value")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class MochiSTGPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
r"""
The mochi pipeline for text-to-video generation.
Reference: https://github.com/genmoai/models
Args:
transformer ([`MochiTransformer3DModel`]):
Conditional Transformer architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLMochi`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKLMochi,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: MochiTransformer3DModel,
force_zeros_for_empty_prompt: bool = False,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
# TODO: determine these scaling factors from model parameters
self.vae_spatial_scale_factor = 8
self.vae_temporal_scale_factor = 6
self.patch_size = 2
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
)
self.default_height = 480
self.default_width = 848
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
# The original Mochi implementation zeros out empty negative prompts
# but this can lead to overflow when placing the entire pipeline under the autocast context
# adding this here so that we can enable zeroing prompts if necessary
if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
text_input_ids = torch.zeros_like(text_input_ids, device=device)
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, prompt_attention_mask
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
def check_inputs(
self,
prompt,
height,
width,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
num_frames,
dtype,
device,
generator,
latents=None,
):
height = height // self.vae_spatial_scale_factor
width = width // self.vae_spatial_scale_factor
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
shape = (batch_size, num_channels_latents, num_frames, height, width)
if latents is not None:
return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
latents = latents.to(dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def do_spatio_temporal_guidance(self):
return self._stg_scale > 0.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 19,
num_inference_steps: int = 64,
timesteps: List[int] = None,
guidance_scale: float = 4.5,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
stg_applied_layers_idx: Optional[List[int]] = [34],
stg_scale: Optional[float] = 0.0,
do_rescaling: Optional[bool] = False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to `self.default_height`):
The height in pixels of the generated image. This is set to 480 by default for the best results.
width (`int`, *optional*, defaults to `self.default_width`):
The width in pixels of the generated image. This is set to 848 by default for the best results.
num_frames (`int`, defaults to `19`):
The number of video frames to generate
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `4.5`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to `256`):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
is returned where the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.default_height
width = width or self.default_width
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._stg_scale = stg_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
if self.do_spatio_temporal_guidance:
for i in stg_applied_layers_idx:
self.transformer.transformer_blocks[i].forward = types.MethodType(
forward_with_stg, self.transformer.transformer_blocks[i]
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Prepare text embeddings
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
device=device,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
prompt_embeds.dtype,
device,
generator,
latents,
)
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat(
[negative_prompt_attention_mask, prompt_attention_mask, prompt_attention_mask], dim=0
)
# 5. Prepare timestep
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
threshold_noise = 0.025
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
sigmas = np.array(sigmas)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# Note: Mochi uses reversed timesteps. To ensure compatibility with methods like FasterCache, we need
# to make sure we're using the correct non-reversed timestep value.
self._current_timestep = 1000 - t
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
latent_model_input = torch.cat([latents] * 2)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
latent_model_input = torch.cat([latents] * 3)
else:
latent_model_input = latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)
if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
elif self.do_classifier_free_guidance and self.do_spatio_temporal_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = noise_pred.chunk(3)
noise_pred = (
noise_pred_uncond
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ self._stg_scale * (noise_pred_text - noise_pred_perturb)
)
if do_rescaling:
rescaling_scale = 0.7
factor = noise_pred_text.std() / noise_pred.std()
factor = rescaling_scale * factor + (1 - rescaling_scale)
noise_pred = noise_pred * factor
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
latents = latents.to(latents_dtype)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if output_type == "latent":
video = latents
else:
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
else:
latents = latents / self.vae.config.scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return MochiPipelineOutput(frames=video)
@@ -1334,7 +1334,9 @@ def main(args):
# run inference
if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
for _ in range(args.num_validation_images)
+1 -1
View File
@@ -172,7 +172,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
+20 -8
View File
@@ -150,7 +150,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
if args.validation_images is None:
images = []
@@ -1119,17 +1119,22 @@ def main(args):
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1146,8 +1151,15 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -181,7 +181,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -167,7 +167,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx:
@@ -170,7 +170,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
@@ -199,7 +199,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -207,7 +207,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
+1 -1
View File
@@ -175,7 +175,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
@@ -0,0 +1,32 @@
# AnyTextPipeline Pipeline
Project page: https://aigcdesigngroup.github.io/homepage_anytext
"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
```py
import torch
from diffusers import DiffusionPipeline
from anytext_controlnet import AnyTextControlNetModel
from diffusers.utils import load_image
# I chose a font file shared by an HF staff:
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
variant="fp16",)
pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
controlnet=anytext_controlnet, torch_dtype=torch.float16,
trust_remote_code=False, # One needs to give permission to run this pipeline's code
).to("cuda")
# generate image
prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
).images[0]
image
```
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,463 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Based on [AnyText: Multilingual Visual Text Generation And Editing](https://huggingface.co/papers/2311.03054).
# Authors: Yuxiang Tuo, Wangmeng Xiang, Jun-Yan He, Yifeng Geng, Xuansong Xie
# Code: https://github.com/tyxsspa/AnyText with Apache-2.0 license
#
# Adapted to Diffusers by [M. Tolga Cangöz](https://github.com/tolgacangoz).
from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.configuration_utils import register_to_config
from diffusers.models.controlnets.controlnet import (
ControlNetModel,
ControlNetOutput,
)
from diffusers.utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class AnyTextControlNetConditioningEmbedding(nn.Module):
"""
Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
[11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 latent images for stabilized
training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
(activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
model) to encode image-space conditions ... into feature maps ..."
"""
def __init__(
self,
conditioning_embedding_channels: int,
glyph_channels=1,
position_channels=1,
):
super().__init__()
self.glyph_block = nn.Sequential(
nn.Conv2d(glyph_channels, 8, 3, padding=1),
nn.SiLU(),
nn.Conv2d(8, 8, 3, padding=1),
nn.SiLU(),
nn.Conv2d(8, 16, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 32, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(32, 32, 3, padding=1),
nn.SiLU(),
nn.Conv2d(32, 96, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(96, 96, 3, padding=1),
nn.SiLU(),
nn.Conv2d(96, 256, 3, padding=1, stride=2),
nn.SiLU(),
)
self.position_block = nn.Sequential(
nn.Conv2d(position_channels, 8, 3, padding=1),
nn.SiLU(),
nn.Conv2d(8, 8, 3, padding=1),
nn.SiLU(),
nn.Conv2d(8, 16, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1),
nn.SiLU(),
nn.Conv2d(16, 32, 3, padding=1, stride=2),
nn.SiLU(),
nn.Conv2d(32, 32, 3, padding=1),
nn.SiLU(),
nn.Conv2d(32, 64, 3, padding=1, stride=2),
nn.SiLU(),
)
self.fuse_block = nn.Conv2d(256 + 64 + 4, conditioning_embedding_channels, 3, padding=1)
def forward(self, glyphs, positions, text_info):
glyph_embedding = self.glyph_block(glyphs.to(self.glyph_block[0].weight.device))
position_embedding = self.position_block(positions.to(self.position_block[0].weight.device))
guided_hint = self.fuse_block(torch.cat([glyph_embedding, position_embedding, text_info["masked_x"]], dim=1))
return guided_hint
class AnyTextControlNetModel(ControlNetModel):
"""
A AnyTextControlNetModel model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
TODO(Patrick) - unused parameter.
addition_embed_type_num_heads (`int`, defaults to 64):
The number of heads to use for the `TextTimeEmbedding` layer.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 1,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
super().__init__(
in_channels,
conditioning_channels,
flip_sin_to_cos,
freq_shift,
down_block_types,
mid_block_type,
only_cross_attention,
block_out_channels,
layers_per_block,
downsample_padding,
mid_block_scale_factor,
act_fn,
norm_num_groups,
norm_eps,
cross_attention_dim,
transformer_layers_per_block,
encoder_hid_dim,
encoder_hid_dim_type,
attention_head_dim,
num_attention_heads,
use_linear_projection,
class_embed_type,
addition_embed_type,
addition_time_embed_dim,
num_class_embeds,
upcast_attention,
resnet_time_scale_shift,
projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order,
conditioning_embedding_out_channels,
global_pool_conditions,
addition_embed_type_num_heads,
)
# control net conditioning embedding
self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
conditioning_embedding_channels=block_out_channels[0],
glyph_channels=conditioning_channels,
position_channels=conditioning_channels,
)
def forward(
self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
"""
The [`~PromptDiffusionControlNetModel`] forward method.
Args:
sample (`torch.Tensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
#controlnet_cond (`torch.Tensor`):
# The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Returns:
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
# elif channel_order == "bgr":
# controlnet_cond = torch.flip(controlnet_cond, dims=[1])
else:
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if self.config.addition_embed_type is not None:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process
sample = self.conv_in(sample)
controlnet_cond = self.controlnet_cond_embedding(*controlnet_cond)
sample = sample + controlnet_cond
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample = self.mid_block(sample, emb)
# 5. Control net blocks
controlnet_down_block_res_samples = ()
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
down_block_res_sample = controlnet_block(down_block_res_sample)
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = controlnet_down_block_res_samples
mid_block_res_sample = self.controlnet_mid_block(sample)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
if self.config.global_pool_conditions:
down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
]
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
if not return_dict:
return (down_block_res_samples, mid_block_res_sample)
return ControlNetOutput(
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
)
# Copied from diffusers.models.controlnet.zero_module
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module
+209
View File
@@ -0,0 +1,209 @@
import torch
from torch import nn
from .RecSVTR import Block
class Swish(nn.Module):
def __int__(self):
super(Swish, self).__int__()
def forward(self, x):
return x * torch.sigmoid(x)
class Im2Im(nn.Module):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
def forward(self, x):
return x
class Im2Seq(nn.Module):
def __init__(self, in_channels, **kwargs):
super().__init__()
self.out_channels = in_channels
def forward(self, x):
B, C, H, W = x.shape
# assert H == 1
x = x.reshape(B, C, H * W)
x = x.permute((0, 2, 1))
return x
class EncoderWithRNN(nn.Module):
def __init__(self, in_channels, **kwargs):
super(EncoderWithRNN, self).__init__()
hidden_size = kwargs.get("hidden_size", 256)
self.out_channels = hidden_size * 2
self.lstm = nn.LSTM(in_channels, hidden_size, bidirectional=True, num_layers=2, batch_first=True)
def forward(self, x):
self.lstm.flatten_parameters()
x, _ = self.lstm(x)
return x
class SequenceEncoder(nn.Module):
def __init__(self, in_channels, encoder_type="rnn", **kwargs):
super(SequenceEncoder, self).__init__()
self.encoder_reshape = Im2Seq(in_channels)
self.out_channels = self.encoder_reshape.out_channels
self.encoder_type = encoder_type
if encoder_type == "reshape":
self.only_reshape = True
else:
support_encoder_dict = {"reshape": Im2Seq, "rnn": EncoderWithRNN, "svtr": EncoderWithSVTR}
assert encoder_type in support_encoder_dict, "{} must in {}".format(
encoder_type, support_encoder_dict.keys()
)
self.encoder = support_encoder_dict[encoder_type](self.encoder_reshape.out_channels, **kwargs)
self.out_channels = self.encoder.out_channels
self.only_reshape = False
def forward(self, x):
if self.encoder_type != "svtr":
x = self.encoder_reshape(x)
if not self.only_reshape:
x = self.encoder(x)
return x
else:
x = self.encoder(x)
x = self.encoder_reshape(x)
return x
class ConvBNLayer(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
bias=bias_attr,
)
self.norm = nn.BatchNorm2d(out_channels)
self.act = Swish()
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class EncoderWithSVTR(nn.Module):
def __init__(
self,
in_channels,
dims=64, # XS
depth=2,
hidden_dims=120,
use_guide=False,
num_heads=8,
qkv_bias=True,
mlp_ratio=2.0,
drop_rate=0.1,
attn_drop_rate=0.1,
drop_path=0.0,
qk_scale=None,
):
super(EncoderWithSVTR, self).__init__()
self.depth = depth
self.use_guide = use_guide
self.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1, act="swish")
self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1, act="swish")
self.svtr_block = nn.ModuleList(
[
Block(
dim=hidden_dims,
num_heads=num_heads,
mixer="Global",
HW=None,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer="swish",
attn_drop=attn_drop_rate,
drop_path=drop_path,
norm_layer="nn.LayerNorm",
epsilon=1e-05,
prenorm=False,
)
for i in range(depth)
]
)
self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)
self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1, act="swish")
# last conv-nxn, the input is concat of input tensor and conv3 output tensor
self.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1, act="swish")
self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1, act="swish")
self.out_channels = dims
self.apply(self._init_weights)
def _init_weights(self, m):
# weight initialization
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.ConvTranspose2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out")
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
def forward(self, x):
# for use guide
if self.use_guide:
z = x.clone()
z.stop_gradient = True
else:
z = x
# for short cut
h = z
# reduce dim
z = self.conv1(z)
z = self.conv2(z)
# SVTR global block
B, C, H, W = z.shape
z = z.flatten(2).permute(0, 2, 1)
for blk in self.svtr_block:
z = blk(z)
z = self.norm(z)
# last stage
z = z.reshape([-1, H, W, C]).permute(0, 3, 1, 2)
z = self.conv3(z)
z = torch.cat((h, z), dim=1)
z = self.conv1x1(self.conv4(z))
return z
if __name__ == "__main__":
svtrRNN = EncoderWithSVTR(56)
print(svtrRNN)
+45
View File
@@ -0,0 +1,45 @@
from torch import nn
class CTCHead(nn.Module):
def __init__(
self, in_channels, out_channels=6625, fc_decay=0.0004, mid_channels=None, return_feats=False, **kwargs
):
super(CTCHead, self).__init__()
if mid_channels is None:
self.fc = nn.Linear(
in_channels,
out_channels,
bias=True,
)
else:
self.fc1 = nn.Linear(
in_channels,
mid_channels,
bias=True,
)
self.fc2 = nn.Linear(
mid_channels,
out_channels,
bias=True,
)
self.out_channels = out_channels
self.mid_channels = mid_channels
self.return_feats = return_feats
def forward(self, x, labels=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
x = self.fc1(x)
predicts = self.fc2(x)
if self.return_feats:
result = {}
result["ctc"] = predicts
result["ctc_neck"] = x
else:
result = predicts
return result
+49
View File
@@ -0,0 +1,49 @@
from torch import nn
from .RecCTCHead import CTCHead
from .RecMv1_enhance import MobileNetV1Enhance
from .RNN import Im2Im, Im2Seq, SequenceEncoder
backbone_dict = {"MobileNetV1Enhance": MobileNetV1Enhance}
neck_dict = {"SequenceEncoder": SequenceEncoder, "Im2Seq": Im2Seq, "None": Im2Im}
head_dict = {"CTCHead": CTCHead}
class RecModel(nn.Module):
def __init__(self, config):
super().__init__()
assert "in_channels" in config, "in_channels must in model config"
backbone_type = config["backbone"].pop("type")
assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
self.backbone = backbone_dict[backbone_type](config["in_channels"], **config["backbone"])
neck_type = config["neck"].pop("type")
assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
self.neck = neck_dict[neck_type](self.backbone.out_channels, **config["neck"])
head_type = config["head"].pop("type")
assert head_type in head_dict, f"head.type must in {head_dict}"
self.head = head_dict[head_type](self.neck.out_channels, **config["head"])
self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"
def load_3rd_state_dict(self, _3rd_name, _state):
self.backbone.load_3rd_state_dict(_3rd_name, _state)
self.neck.load_3rd_state_dict(_3rd_name, _state)
self.head.load_3rd_state_dict(_3rd_name, _state)
def forward(self, x):
import torch
x = x.to(torch.float32)
x = self.backbone(x)
x = self.neck(x)
x = self.head(x)
return x
def encode(self, x):
x = self.backbone(x)
x = self.neck(x)
x = self.head.ctc_encoder(x)
return x
@@ -0,0 +1,197 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .common import Activation
class ConvBNLayer(nn.Module):
def __init__(
self, num_channels, filter_size, num_filters, stride, padding, channels=None, num_groups=1, act="hard_swish"
):
super(ConvBNLayer, self).__init__()
self.act = act
self._conv = nn.Conv2d(
in_channels=num_channels,
out_channels=num_filters,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=num_groups,
bias=False,
)
self._batch_norm = nn.BatchNorm2d(
num_filters,
)
if self.act is not None:
self._act = Activation(act_type=act, inplace=True)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
if self.act is not None:
y = self._act(y)
return y
class DepthwiseSeparable(nn.Module):
def __init__(
self, num_channels, num_filters1, num_filters2, num_groups, stride, scale, dw_size=3, padding=1, use_se=False
):
super(DepthwiseSeparable, self).__init__()
self.use_se = use_se
self._depthwise_conv = ConvBNLayer(
num_channels=num_channels,
num_filters=int(num_filters1 * scale),
filter_size=dw_size,
stride=stride,
padding=padding,
num_groups=int(num_groups * scale),
)
if use_se:
self._se = SEModule(int(num_filters1 * scale))
self._pointwise_conv = ConvBNLayer(
num_channels=int(num_filters1 * scale),
filter_size=1,
num_filters=int(num_filters2 * scale),
stride=1,
padding=0,
)
def forward(self, inputs):
y = self._depthwise_conv(inputs)
if self.use_se:
y = self._se(y)
y = self._pointwise_conv(y)
return y
class MobileNetV1Enhance(nn.Module):
def __init__(self, in_channels=3, scale=0.5, last_conv_stride=1, last_pool_type="max", **kwargs):
super().__init__()
self.scale = scale
self.block_list = []
self.conv1 = ConvBNLayer(
num_channels=in_channels, filter_size=3, channels=3, num_filters=int(32 * scale), stride=2, padding=1
)
conv2_1 = DepthwiseSeparable(
num_channels=int(32 * scale), num_filters1=32, num_filters2=64, num_groups=32, stride=1, scale=scale
)
self.block_list.append(conv2_1)
conv2_2 = DepthwiseSeparable(
num_channels=int(64 * scale), num_filters1=64, num_filters2=128, num_groups=64, stride=1, scale=scale
)
self.block_list.append(conv2_2)
conv3_1 = DepthwiseSeparable(
num_channels=int(128 * scale), num_filters1=128, num_filters2=128, num_groups=128, stride=1, scale=scale
)
self.block_list.append(conv3_1)
conv3_2 = DepthwiseSeparable(
num_channels=int(128 * scale),
num_filters1=128,
num_filters2=256,
num_groups=128,
stride=(2, 1),
scale=scale,
)
self.block_list.append(conv3_2)
conv4_1 = DepthwiseSeparable(
num_channels=int(256 * scale), num_filters1=256, num_filters2=256, num_groups=256, stride=1, scale=scale
)
self.block_list.append(conv4_1)
conv4_2 = DepthwiseSeparable(
num_channels=int(256 * scale),
num_filters1=256,
num_filters2=512,
num_groups=256,
stride=(2, 1),
scale=scale,
)
self.block_list.append(conv4_2)
for _ in range(5):
conv5 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=512,
num_groups=512,
stride=1,
dw_size=5,
padding=2,
scale=scale,
use_se=False,
)
self.block_list.append(conv5)
conv5_6 = DepthwiseSeparable(
num_channels=int(512 * scale),
num_filters1=512,
num_filters2=1024,
num_groups=512,
stride=(2, 1),
dw_size=5,
padding=2,
scale=scale,
use_se=True,
)
self.block_list.append(conv5_6)
conv6 = DepthwiseSeparable(
num_channels=int(1024 * scale),
num_filters1=1024,
num_filters2=1024,
num_groups=1024,
stride=last_conv_stride,
dw_size=5,
padding=2,
use_se=True,
scale=scale,
)
self.block_list.append(conv6)
self.block_list = nn.Sequential(*self.block_list)
if last_pool_type == "avg":
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
else:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.out_channels = int(1024 * scale)
def forward(self, inputs):
y = self.conv1(inputs)
y = self.block_list(y)
y = self.pool(y)
return y
def hardsigmoid(x):
return F.relu6(x + 3.0, inplace=True) / 6.0
class SEModule(nn.Module):
def __init__(self, channel, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(
in_channels=channel, out_channels=channel // reduction, kernel_size=1, stride=1, padding=0, bias=True
)
self.conv2 = nn.Conv2d(
in_channels=channel // reduction, out_channels=channel, kernel_size=1, stride=1, padding=0, bias=True
)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
outputs = hardsigmoid(outputs)
x = torch.mul(inputs, outputs)
return x
@@ -0,0 +1,570 @@
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional
from torch.nn.init import ones_, trunc_normal_, zeros_
def drop_path(x, drop_prob=0.0, training=False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = torch.tensor(1 - drop_prob)
shape = (x.size()[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype)
random_tensor = torch.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
return output
class Swish(nn.Module):
def __int__(self):
super(Swish, self).__int__()
def forward(self, x):
return x * torch.sigmoid(x)
class ConvBNLayer(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias_attr=False, groups=1, act=nn.GELU
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
# weight_attr=paddle.ParamAttr(initializer=nn.initializer.KaimingUniform()),
bias=bias_attr,
)
self.norm = nn.BatchNorm2d(out_channels)
self.act = act()
def forward(self, inputs):
out = self.conv(inputs)
out = self.norm(out)
out = self.act(out)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, input):
return input
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
if isinstance(act_layer, str):
self.act = Swish()
else:
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMixer(nn.Module):
def __init__(
self,
dim,
num_heads=8,
HW=(8, 25),
local_k=(3, 3),
):
super().__init__()
self.HW = HW
self.dim = dim
self.local_mixer = nn.Conv2d(
dim,
dim,
local_k,
1,
(local_k[0] // 2, local_k[1] // 2),
groups=num_heads,
# weight_attr=ParamAttr(initializer=KaimingNormal())
)
def forward(self, x):
h = self.HW[0]
w = self.HW[1]
x = x.transpose([0, 2, 1]).reshape([0, self.dim, h, w])
x = self.local_mixer(x)
x = x.flatten(2).transpose([0, 2, 1])
return x
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads=8,
mixer="Global",
HW=(8, 25),
local_k=(7, 11),
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.HW = HW
if HW is not None:
H = HW[0]
W = HW[1]
self.N = H * W
self.C = dim
if mixer == "Local" and HW is not None:
hk = local_k[0]
wk = local_k[1]
mask = torch.ones([H * W, H + hk - 1, W + wk - 1])
for h in range(0, H):
for w in range(0, W):
mask[h * W + w, h : h + hk, w : w + wk] = 0.0
mask_paddle = mask[:, hk // 2 : H + hk // 2, wk // 2 : W + wk // 2].flatten(1)
mask_inf = torch.full([H * W, H * W], fill_value=float("-inf"))
mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf)
self.mask = mask[None, None, :]
# self.mask = mask.unsqueeze([0, 1])
self.mixer = mixer
def forward(self, x):
if self.HW is not None:
N = self.N
C = self.C
else:
_, N, C = x.shape
qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4))
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q.matmul(k.permute((0, 1, 3, 2)))
if self.mixer == "Local":
attn += self.mask
attn = functional.softmax(attn, dim=-1)
attn = self.attn_drop(attn)
x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C))
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mixer="Global",
local_mixer=(7, 11),
HW=(8, 25),
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer="nn.LayerNorm",
epsilon=1e-6,
prenorm=True,
):
super().__init__()
if isinstance(norm_layer, str):
self.norm1 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm1 = norm_layer(dim)
if mixer == "Global" or mixer == "Local":
self.mixer = Attention(
dim,
num_heads=num_heads,
mixer=mixer,
HW=HW,
local_k=local_mixer,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
elif mixer == "Conv":
self.mixer = ConvMixer(dim, num_heads=num_heads, HW=HW, local_k=local_mixer)
else:
raise TypeError("The mixer must be one of [Global, Local, Conv]")
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
if isinstance(norm_layer, str):
self.norm2 = eval(norm_layer)(dim, eps=epsilon)
else:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp_ratio = mlp_ratio
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.prenorm = prenorm
def forward(self, x):
if self.prenorm:
x = self.norm1(x + self.drop_path(self.mixer(x)))
x = self.norm2(x + self.drop_path(self.mlp(x)))
else:
x = x + self.drop_path(self.mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=(32, 100), in_channels=3, embed_dim=768, sub_num=2):
super().__init__()
num_patches = (img_size[1] // (2**sub_num)) * (img_size[0] // (2**sub_num))
self.img_size = img_size
self.num_patches = num_patches
self.embed_dim = embed_dim
self.norm = None
if sub_num == 2:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False,
),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False,
),
)
if sub_num == 3:
self.proj = nn.Sequential(
ConvBNLayer(
in_channels=in_channels,
out_channels=embed_dim // 4,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False,
),
ConvBNLayer(
in_channels=embed_dim // 4,
out_channels=embed_dim // 2,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False,
),
ConvBNLayer(
in_channels=embed_dim // 2,
out_channels=embed_dim,
kernel_size=3,
stride=2,
padding=1,
act=nn.GELU,
bias_attr=False,
),
)
def forward(self, x):
B, C, H, W = x.shape
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).permute(0, 2, 1)
return x
class SubSample(nn.Module):
def __init__(self, in_channels, out_channels, types="Pool", stride=(2, 1), sub_norm="nn.LayerNorm", act=None):
super().__init__()
self.types = types
if types == "Pool":
self.avgpool = nn.AvgPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
self.maxpool = nn.MaxPool2d(kernel_size=(3, 5), stride=stride, padding=(1, 2))
self.proj = nn.Linear(in_channels, out_channels)
else:
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
# weight_attr=ParamAttr(initializer=KaimingNormal())
)
self.norm = eval(sub_norm)(out_channels)
if act is not None:
self.act = act()
else:
self.act = None
def forward(self, x):
if self.types == "Pool":
x1 = self.avgpool(x)
x2 = self.maxpool(x)
x = (x1 + x2) * 0.5
out = self.proj(x.flatten(2).permute((0, 2, 1)))
else:
x = self.conv(x)
out = x.flatten(2).permute((0, 2, 1))
out = self.norm(out)
if self.act is not None:
out = self.act(out)
return out
class SVTRNet(nn.Module):
def __init__(
self,
img_size=[48, 100],
in_channels=3,
embed_dim=[64, 128, 256],
depth=[3, 6, 3],
num_heads=[2, 4, 8],
mixer=["Local"] * 6 + ["Global"] * 6, # Local atten, Global atten, Conv
local_mixer=[[7, 11], [7, 11], [7, 11]],
patch_merging="Conv", # Conv, Pool, None
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
last_drop=0.1,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer="nn.LayerNorm",
sub_norm="nn.LayerNorm",
epsilon=1e-6,
out_channels=192,
out_char_num=25,
block_unit="Block",
act="nn.GELU",
last_stage=True,
sub_num=2,
prenorm=True,
use_lenhead=False,
**kwargs,
):
super().__init__()
self.img_size = img_size
self.embed_dim = embed_dim
self.out_channels = out_channels
self.prenorm = prenorm
patch_merging = None if patch_merging != "Conv" and patch_merging != "Pool" else patch_merging
self.patch_embed = PatchEmbed(
img_size=img_size, in_channels=in_channels, embed_dim=embed_dim[0], sub_num=sub_num
)
num_patches = self.patch_embed.num_patches
self.HW = [img_size[0] // (2**sub_num), img_size[1] // (2**sub_num)]
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim[0]))
# self.pos_embed = self.create_parameter(
# shape=[1, num_patches, embed_dim[0]], default_initializer=zeros_)
# self.add_parameter("pos_embed", self.pos_embed)
self.pos_drop = nn.Dropout(p=drop_rate)
Block_unit = eval(block_unit)
dpr = np.linspace(0, drop_path_rate, sum(depth))
self.blocks1 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[0],
num_heads=num_heads[0],
mixer=mixer[0 : depth[0]][i],
HW=self.HW,
local_mixer=local_mixer[0],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[0 : depth[0]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[0])
]
)
if patch_merging is not None:
self.sub_sample1 = SubSample(
embed_dim[0], embed_dim[1], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
)
HW = [self.HW[0] // 2, self.HW[1]]
else:
HW = self.HW
self.patch_merging = patch_merging
self.blocks2 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[1],
num_heads=num_heads[1],
mixer=mixer[depth[0] : depth[0] + depth[1]][i],
HW=HW,
local_mixer=local_mixer[1],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] : depth[0] + depth[1]][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[1])
]
)
if patch_merging is not None:
self.sub_sample2 = SubSample(
embed_dim[1], embed_dim[2], sub_norm=sub_norm, stride=[2, 1], types=patch_merging
)
HW = [self.HW[0] // 4, self.HW[1]]
else:
HW = self.HW
self.blocks3 = nn.ModuleList(
[
Block_unit(
dim=embed_dim[2],
num_heads=num_heads[2],
mixer=mixer[depth[0] + depth[1] :][i],
HW=HW,
local_mixer=local_mixer[2],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
act_layer=eval(act),
attn_drop=attn_drop_rate,
drop_path=dpr[depth[0] + depth[1] :][i],
norm_layer=norm_layer,
epsilon=epsilon,
prenorm=prenorm,
)
for i in range(depth[2])
]
)
self.last_stage = last_stage
if last_stage:
self.avg_pool = nn.AdaptiveAvgPool2d((1, out_char_num))
self.last_conv = nn.Conv2d(
in_channels=embed_dim[2],
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
)
self.hardswish = nn.Hardswish()
self.dropout = nn.Dropout(p=last_drop)
if not prenorm:
self.norm = eval(norm_layer)(embed_dim[-1], epsilon=epsilon)
self.use_lenhead = use_lenhead
if use_lenhead:
self.len_conv = nn.Linear(embed_dim[2], self.out_channels)
self.hardswish_len = nn.Hardswish()
self.dropout_len = nn.Dropout(p=last_drop)
trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
zeros_(m.bias)
elif isinstance(m, nn.LayerNorm):
zeros_(m.bias)
ones_(m.weight)
def forward_features(self, x):
x = self.patch_embed(x)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks1:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample1(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[0], self.HW[0], self.HW[1]]))
for blk in self.blocks2:
x = blk(x)
if self.patch_merging is not None:
x = self.sub_sample2(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]]))
for blk in self.blocks3:
x = blk(x)
if not self.prenorm:
x = self.norm(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.use_lenhead:
len_x = self.len_conv(x.mean(1))
len_x = self.dropout_len(self.hardswish_len(len_x))
if self.last_stage:
if self.patch_merging is not None:
h = self.HW[0] // 4
else:
h = self.HW[0]
x = self.avg_pool(x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]]))
x = self.last_conv(x)
x = self.hardswish(x)
x = self.dropout(x)
if self.use_lenhead:
return x, len_x
return x
if __name__ == "__main__":
a = torch.rand(1, 3, 48, 100)
svtr = SVTRNet()
out = svtr(a)
print(svtr)
print(out.size())
@@ -0,0 +1,74 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Hswish(nn.Module):
def __init__(self, inplace=True):
super(Hswish, self).__init__()
self.inplace = inplace
def forward(self, x):
return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0
# out = max(0, min(1, slop*x+offset))
# paddle.fluid.layers.hard_sigmoid(x, slope=0.2, offset=0.5, name=None)
class Hsigmoid(nn.Module):
def __init__(self, inplace=True):
super(Hsigmoid, self).__init__()
self.inplace = inplace
def forward(self, x):
# torch: F.relu6(x + 3., inplace=self.inplace) / 6.
# paddle: F.relu6(1.2 * x + 3., inplace=self.inplace) / 6.
return F.relu6(1.2 * x + 3.0, inplace=self.inplace) / 6.0
class GELU(nn.Module):
def __init__(self, inplace=True):
super(GELU, self).__init__()
self.inplace = inplace
def forward(self, x):
return torch.nn.functional.gelu(x)
class Swish(nn.Module):
def __init__(self, inplace=True):
super(Swish, self).__init__()
self.inplace = inplace
def forward(self, x):
if self.inplace:
x.mul_(torch.sigmoid(x))
return x
else:
return x * torch.sigmoid(x)
class Activation(nn.Module):
def __init__(self, act_type, inplace=True):
super(Activation, self).__init__()
act_type = act_type.lower()
if act_type == "relu":
self.act = nn.ReLU(inplace=inplace)
elif act_type == "relu6":
self.act = nn.ReLU6(inplace=inplace)
elif act_type == "sigmoid":
raise NotImplementedError
elif act_type == "hard_sigmoid":
self.act = Hsigmoid(inplace)
elif act_type == "hard_swish":
self.act = Hswish(inplace=inplace)
elif act_type == "leakyrelu":
self.act = nn.LeakyReLU(inplace=inplace)
elif act_type == "gelu":
self.act = GELU(inplace=inplace)
elif act_type == "swish":
self.act = Swish(inplace=inplace)
else:
raise NotImplementedError
def forward(self, inputs):
return self.act(inputs)
@@ -0,0 +1,95 @@
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
@@ -6,4 +6,4 @@ torch==2.2.0
torchvision>=0.16
ftfy==6.1.1
tensorboard==2.14.0
Jinja2==3.1.5
Jinja2==3.1.6
@@ -137,7 +137,7 @@ def log_validation(
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
pipeline_args = {"prompt": args.validation_prompt}
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
@@ -1241,7 +1241,11 @@ def main(args):
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.seed is not None
else None
)
pipeline_args = {"prompt": args.validation_prompt}
with autocast_ctx:
@@ -1305,7 +1309,9 @@ def main(args):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
generator = (
torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
)
with autocast_ctx:
images = [
+94 -19
View File
@@ -3,11 +3,19 @@ from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
from transformers import (
AutoModel,
AutoTokenizer,
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
LlavaForConditionalGeneration,
)
from diffusers import (
AutoencoderKLHunyuanVideo,
FlowMatchEulerDiscreteScheduler,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
@@ -134,6 +142,46 @@ VAE_KEYS_RENAME_DICT = {}
VAE_SPECIAL_KEYS_REMAP = {}
TRANSFORMER_CONFIGS = {
"HYVideo-T/2-cfgdistill": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": True,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
},
"HYVideo-T/2-I2V": {
"in_channels": 16 * 2 + 1,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": False,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
},
}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
@@ -149,11 +197,12 @@ def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
return state_dict
def convert_transformer(ckpt_path: str):
def convert_transformer(ckpt_path: str, transformer_type: str):
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
config = TRANSFORMER_CONFIGS[transformer_type]
with init_empty_weights():
transformer = HunyuanVideoTransformer3DModel()
transformer = HunyuanVideoTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -205,6 +254,10 @@ def get_args():
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
parser.add_argument(
"--transformer_type", type=str, default="HYVideo-T/2-cfgdistill", choices=list(TRANSFORMER_CONFIGS.keys())
)
parser.add_argument("--flow_shift", type=float, default=7.0)
return parser.parse_args()
@@ -228,7 +281,7 @@ if __name__ == "__main__":
assert args.text_encoder_2_path is not None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_ckpt_path)
transformer = convert_transformer(args.transformer_ckpt_path, args.transformer_type)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
@@ -239,19 +292,41 @@ if __name__ == "__main__":
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
if args.transformer_type == "HYVideo-T/2-cfgdistill":
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
pipe = HunyuanVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
pipe = HunyuanVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
else:
text_encoder = LlavaForConditionalGeneration.from_pretrained(
args.text_encoder_path, torch_dtype=torch.float16
)
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
scheduler = FlowMatchEulerDiscreteScheduler(shift=args.flow_shift)
image_processor = CLIPImageProcessor.from_pretrained(args.text_encoder_path)
pipe = HunyuanVideoImageToVideoPipeline(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
scheduler=scheduler,
image_processor=image_processor,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+423
View File
@@ -0,0 +1,423 @@
import argparse
import pathlib
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download, snapshot_download
from safetensors.torch import load_file
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
from diffusers import (
AutoencoderKLWan,
UniPCMultistepScheduler,
WanImageToVideoPipeline,
WanPipeline,
WanTransformer3DModel,
)
TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
# For the I2V model
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def load_sharded_safetensors(dir: pathlib.Path):
file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
state_dict = {}
for path in file_paths:
state_dict.update(load_file(path))
return state_dict
def get_transformer_config(model_type: str) -> Dict[str, Any]:
if model_type == "Wan-T2V-1.3B":
config = {
"model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 12,
"num_layers": 30,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "Wan-T2V-14B":
config = {
"model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
"diffusers_config": {
"added_kv_proj_dim": None,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 16,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "Wan-I2V-14B-480p":
config = {
"model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
elif model_type == "Wan-I2V-14B-720p":
config = {
"model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
"diffusers_config": {
"image_dim": 1280,
"added_kv_proj_dim": 5120,
"attention_head_dim": 128,
"cross_attn_norm": True,
"eps": 1e-06,
"ffn_dim": 13824,
"freq_dim": 256,
"in_channels": 36,
"num_attention_heads": 40,
"num_layers": 40,
"out_channels": 16,
"patch_size": [1, 2, 2],
"qk_norm": "rms_norm_across_heads",
"text_dim": 4096,
},
}
return config
def convert_transformer(model_type: str):
config = get_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
model_id = config["model_id"]
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
transformer = WanTransformer3DModel.from_config(diffusers_config)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae():
vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
new_state_dict = {}
# Create mappings for specific components
middle_key_mapping = {
# Encoder middle block
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
# Decoder middle block
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
}
# Create a mapping for attention blocks
attention_mapping = {
# Encoder middle attention
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
# Decoder middle attention
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
}
# Create a mapping for the head components
head_mapping = {
# Encoder head
"encoder.head.0.gamma": "encoder.norm_out.gamma",
"encoder.head.2.bias": "encoder.conv_out.bias",
"encoder.head.2.weight": "encoder.conv_out.weight",
# Decoder head
"decoder.head.0.gamma": "decoder.norm_out.gamma",
"decoder.head.2.bias": "decoder.conv_out.bias",
"decoder.head.2.weight": "decoder.conv_out.weight",
}
# Create a mapping for the quant components
quant_mapping = {
"conv1.weight": "quant_conv.weight",
"conv1.bias": "quant_conv.bias",
"conv2.weight": "post_quant_conv.weight",
"conv2.bias": "post_quant_conv.bias",
}
# Process each key in the state dict
for key, value in old_state_dict.items():
# Handle middle block keys using the mapping
if key in middle_key_mapping:
new_key = middle_key_mapping[key]
new_state_dict[new_key] = value
# Handle attention blocks using the mapping
elif key in attention_mapping:
new_key = attention_mapping[key]
new_state_dict[new_key] = value
# Handle head keys using the mapping
elif key in head_mapping:
new_key = head_mapping[key]
new_state_dict[new_key] = value
# Handle quant keys using the mapping
elif key in quant_mapping:
new_key = quant_mapping[key]
new_state_dict[new_key] = value
# Handle encoder conv1
elif key == "encoder.conv1.weight":
new_state_dict["encoder.conv_in.weight"] = value
elif key == "encoder.conv1.bias":
new_state_dict["encoder.conv_in.bias"] = value
# Handle decoder conv1
elif key == "decoder.conv1.weight":
new_state_dict["decoder.conv_in.weight"] = value
elif key == "decoder.conv1.bias":
new_state_dict["decoder.conv_in.bias"] = value
# Handle encoder downsamples
elif key.startswith("encoder.downsamples."):
# Convert to down_blocks
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
# Convert residual block naming but keep the original structure
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
elif ".residual.2.bias" in new_key:
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
elif ".residual.2.weight" in new_key:
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
elif ".residual.3.gamma" in new_key:
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
elif ".residual.6.bias" in new_key:
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
elif ".residual.6.weight" in new_key:
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
elif ".shortcut.weight" in new_key:
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
new_state_dict[new_key] = value
# Handle decoder upsamples
elif key.startswith("decoder.upsamples."):
# Convert to up_blocks
parts = key.split(".")
block_idx = int(parts[2])
# Group residual blocks
if "residual" in key:
if block_idx in [0, 1, 2]:
new_block_idx = 0
resnet_idx = block_idx
elif block_idx in [4, 5, 6]:
new_block_idx = 1
resnet_idx = block_idx - 4
elif block_idx in [8, 9, 10]:
new_block_idx = 2
resnet_idx = block_idx - 8
elif block_idx in [12, 13, 14]:
new_block_idx = 3
resnet_idx = block_idx - 12
else:
# Keep as is for other blocks
new_state_dict[key] = value
continue
# Convert residual block naming
if ".residual.0.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
elif ".residual.2.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
elif ".residual.2.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
elif ".residual.3.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
elif ".residual.6.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
elif ".residual.6.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
else:
new_key = key
new_state_dict[new_key] = value
# Handle shortcut connections
elif ".shortcut." in key:
if block_idx == 4:
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
new_state_dict[new_key] = value
# Handle upsamplers
elif ".resample." in key or ".time_conv." in key:
if block_idx == 3:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
elif block_idx == 7:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
elif block_idx == 11:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_state_dict[new_key] = value
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_state_dict[new_key] = value
else:
# Keep other keys unchanged
new_state_dict[key] = value
with init_empty_weights():
vae = AutoencoderKLWan()
vae.load_state_dict(new_state_dict, strict=True, assign=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--output_path", type=str, required=True)
parser.add_argument("--dtype", default="fp32")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
args = get_args()
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
transformer = convert_transformer(args.model_type).to(dtype=dtype)
vae = convert_vae()
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=3.0
)
if "I2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
pipe = WanImageToVideoPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
image_encoder=image_encoder,
image_processor=image_processor,
)
else:
pipe = WanPipeline(
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+9
View File
@@ -128,6 +128,10 @@ _deps = [
"GitPython<3.1.19",
"scipy",
"onnx",
"optimum_quanto>=0.2.6",
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
"regex!=2019.12.17",
"requests",
"tensorboard",
@@ -235,6 +239,11 @@ extras["test"] = deps_list(
)
extras["torch"] = deps_list("torch", "accelerate")
extras["bitsandbytes"] = deps_list("bitsandbytes", "accelerate")
extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows
else:
+106 -2
View File
@@ -6,14 +6,19 @@ from .utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
is_accelerate_available,
is_bitsandbytes_available,
is_flax_available,
is_gguf_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_onnx_available,
is_optimum_quanto_available,
is_scipy_available,
is_sentencepiece_available,
is_torch_available,
is_torchao_available,
is_torchsde_available,
is_transformers_available,
)
@@ -32,7 +37,7 @@ _import_structure = {
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -54,6 +59,54 @@ _import_structure = {
],
}
try:
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_bitsandbytes_objects
_import_structure["utils.dummy_bitsandbytes_objects"] = [
name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
try:
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_gguf_objects
_import_structure["utils.dummy_gguf_objects"] = [
name for name in dir(dummy_gguf_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
try:
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torchao_objects
_import_structure["utils.dummy_torchao_objects"] = [
name for name in dir(dummy_torchao_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
try:
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_optimum_quanto_objects
_import_structure["utils.dummy_optimum_quanto_objects"] = [
name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -94,8 +147,10 @@ else:
"AutoencoderKLCogVideoX",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
"AutoencoderKLWan",
"AutoencoderOobleck",
"AutoencoderTiny",
"CacheMixin",
@@ -108,6 +163,7 @@ else:
"ControlNetUnionModel",
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
@@ -148,6 +204,7 @@ else:
"UNetSpatioTemporalConditionModel",
"UVit2DModel",
"VQModel",
"WanTransformer3DModel",
]
)
_import_structure["optimization"] = [
@@ -291,6 +348,9 @@ else:
"CogView4Pipeline",
"ConsisIDPipeline",
"CycleDiffusionPipeline",
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
"FluxControlNetImg2ImgPipeline",
@@ -306,6 +366,7 @@ else:
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoPipeline",
"I2VGenXLPipeline",
"IFImg2ImgPipeline",
@@ -438,6 +499,8 @@ else:
"VersatileDiffusionTextToImagePipeline",
"VideoToVideoSDPipeline",
"VQDiffusionPipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
@@ -589,7 +652,38 @@ else:
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
try:
if not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_bitsandbytes_objects import *
else:
from .quantizers.quantization_config import BitsAndBytesConfig
try:
if not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_gguf_objects import *
else:
from .quantizers.quantization_config import GGUFQuantizationConfig
try:
if not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torchao_objects import *
else:
from .quantizers.quantization_config import TorchAoConfig
try:
if not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_optimum_quanto_objects import *
else:
from .quantizers.quantization_config import QuantoConfig
try:
if not is_onnx_available():
@@ -616,8 +710,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLCogVideoX,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderTiny,
CacheMixin,
@@ -630,6 +726,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetUnionModel,
ControlNetXSAdapter,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
@@ -669,6 +766,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UNetSpatioTemporalConditionModel,
UVit2DModel,
VQModel,
WanTransformer3DModel,
)
from .optimization import (
get_constant_schedule,
@@ -792,6 +890,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4Pipeline,
ConsisIDPipeline,
CycleDiffusionPipeline,
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
FluxControlNetImg2ImgPipeline,
@@ -807,6 +908,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
I2VGenXLPipeline,
IFImg2ImgPipeline,
@@ -938,6 +1040,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VQDiffusionPipeline,
WanImageToVideoPipeline,
WanPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
@@ -35,6 +35,10 @@ deps = {
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"optimum_quanto": "optimum_quanto>=0.2.6",
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
+132 -33
View File
@@ -29,11 +29,16 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
# Always use memory-efficient CPU offloading to minimize RAM usage
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -56,7 +61,6 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
onload_self: bool = True,
) -> None:
self.modules = modules
@@ -68,12 +72,8 @@ class ModuleGroup:
self.buffers = buffers
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self
if self.stream is not None and self.cpu_param_dict is None:
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
@@ -82,23 +82,125 @@ class ModuleGroup:
self.stream.synchronize()
with context:
for group_module in self.modules:
group_module.to(self.onload_device, non_blocking=self.non_blocking)
# Use the most efficient module-level transfer when possible
# This approach mirrors how PyTorch handles full model transfers
if self.modules:
for group_module in self.modules:
# Only onload if some parameters are not on the target device
if any(p.device != self.onload_device for p in group_module.parameters()):
try:
# Try the most efficient approach using _apply
if hasattr(group_module, "_apply"):
# This is what module.to() uses internally
def to_device(t):
if t.device != self.onload_device:
if self.onload_device.type == "cuda":
return t.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
return t.to(self.onload_device,
non_blocking=self.non_blocking)
return t
# Apply to all tensors without unnecessary copies
group_module._apply(to_device)
else:
# Fallback to direct parameter transfer
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
except Exception as e:
# If optimization fails, fall back to direct parameter transfer
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle explicit parameters
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if buffer.device != self.onload_device:
if self.onload_device.type == "cuda":
buffer.data = buffer.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
buffer.data = buffer.data.to(self.onload_device,
non_blocking=self.non_blocking)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
# For CPU offloading
if self.offload_device.type == "cpu":
# Synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# Empty GPU cache before offloading to reduce memory fragmentation
if torch.cuda.is_available():
torch.cuda.empty_cache()
# For module groups, use a single, unified approach that is closest to
# the behavior of model.to("cpu")
if self.modules:
for group_module in self.modules:
# Check if we need to offload this module
if any(p.device.type != "cpu" for p in group_module.parameters()):
# Use PyTorch's built-in to() method directly, which preserves
# memory mapping when moving to CPU
try:
# Non-blocking=False for CPU transfers, as it ensures memory is
# immediately available and potentially preserves memory mapping
group_module.to("cpu", non_blocking=False)
except Exception as e:
# If there's any error, fall back to parameter-level offloading
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
for param in group_module.parameters():
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle explicit parameters - move directly to CPU with non-blocking=False
# which can preserve memory mapping in some PyTorch versions
if self.parameters is not None:
for param in self.parameters:
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
if buffer.device.type != "cpu":
buffer.data = buffer.data.to("cpu", non_blocking=False)
# Let Python's normal reference counting handle cleanup
# We don't force garbage collection to avoid slowing down inference
else:
# For non-CPU offloading, synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# For non-CPU offloading, use the regular approach
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
@@ -108,6 +210,9 @@ class ModuleGroup:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
# After offloading, we can unpin the memory if configured to do so
# We'll keep it pinned by default for better performance
class GroupOffloadingHook(ModelHook):
r"""
@@ -129,6 +234,7 @@ class GroupOffloadingHook(ModelHook):
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
# Offload to CPU
self.group.offload_()
return module
@@ -313,7 +419,8 @@ def apply_group_offloading(
If True, offloading and onloading is done with non-blocking data transfer.
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
Example:
```python
@@ -344,12 +451,19 @@ def apply_group_offloading(
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
# We no longer need a pinned group manager as we're not using pinned memory
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
module,
num_blocks_per_group,
offload_device,
onload_device,
non_blocking,
stream,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
@@ -384,12 +498,7 @@ def _apply_group_offloading_block_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# We no longer need a CPU parameter dictionary
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -411,7 +520,6 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=stream is None,
)
matched_module_groups.append(group)
@@ -448,7 +556,6 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -483,12 +590,7 @@ def _apply_group_offloading_leaf_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# We no longer need a CPU parameter dictionary
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
@@ -503,7 +605,6 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
@@ -548,7 +649,6 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -567,7 +667,6 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
+4
View File
@@ -70,10 +70,12 @@ if is_torch_available():
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
"CogVideoXLoraLoaderMixin",
"CogView4LoraLoaderMixin",
"Mochi1LoraLoaderMixin",
"HunyuanVideoLoraLoaderMixin",
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
@@ -102,6 +104,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .lora_pipeline import (
AmusedLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
LoraLoaderMixin,
@@ -112,6 +115,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SD3LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin
+7 -5
View File
@@ -215,7 +215,8 @@ class IPAdapterMixin:
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
).to(self.device, dtype=self.dtype)
torch_dtype=self.dtype,
).to(self.device)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
@@ -526,8 +527,9 @@ class FluxIPAdapterMixin:
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
dtype=image_encoder_dtype,
)
.to(self.device, dtype=image_encoder_dtype)
.to(self.device)
.eval()
)
self.register_modules(image_encoder=image_encoder)
@@ -805,9 +807,9 @@ class SD3IPAdapterMixin:
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
self.device, dtype=self.dtype
),
image_encoder=SiglipVisionModel.from_pretrained(
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
).to(self.device),
)
else:
raise ValueError(
+70 -70
View File
@@ -339,93 +339,93 @@ def _load_lora_into_text_encoder(
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
prefix = text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if any(text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(prefix) and k.split(".")[0] == prefix]
text_encoder_lora_state_dict = {
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
state_dict = convert_state_dict_to_diffusers(state_dict)
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
# convert state dict
state_dict = convert_state_dict_to_peft(state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_attn_modules(text_encoder):
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in text_encoder_lora_state_dict:
continue
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
for name, _ in text_encoder_mlp_modules(text_encoder):
for module in ("fc1", "fc2"):
rank_key = f"{name}.{module}.lora_B.weight"
if rank_key not in state_dict:
continue
rank[rank_key] = state_dict[rank_key].shape[1]
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
if "lora_bias" in lora_config_kwargs:
if lora_config_kwargs["lora_bias"]:
if is_peft_version("<=", "0.13.2"):
raise ValueError(
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=state_dict,
peft_config=lora_config,
**peft_kwargs,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
if prefix is not None and not state_dict:
logger.info(
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {text_encoder.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
)
def _func_optionally_disable_offloading(_pipeline):
@@ -654,6 +654,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
_convert(k, diffusers_key, state_dict, new_state_dict)
remaining_all_unet = False
if state_dict:
remaining_all_unet = all(k.startswith("lora_unet_") for k in state_dict)
if remaining_all_unet:
@@ -1276,3 +1277,74 @@ def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict):
# Remove "diffusion_model." prefix from keys.
state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}
converted_state_dict = {}
def get_num_layers(keys, pattern):
layers = set()
for key in keys:
match = re.search(pattern, key)
if match:
layers.add(int(match.group(1)))
return len(layers)
def process_block(prefix, index, convert_norm):
# Process attention qkv: pop lora_A and lora_B weights.
lora_down = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_A.weight")
lora_up = state_dict.pop(f"{prefix}.{index}.attention.qkv.lora_B.weight")
for attn_key in ["to_q", "to_k", "to_v"]:
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_A.weight"] = lora_down
for attn_key, weight in zip(["to_q", "to_k", "to_v"], torch.split(lora_up, [2304, 768, 768], dim=0)):
converted_state_dict[f"{prefix}.{index}.attn.{attn_key}.lora_B.weight"] = weight
# Process attention out weights.
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_A.weight"] = state_dict.pop(
f"{prefix}.{index}.attention.out.lora_A.weight"
)
converted_state_dict[f"{prefix}.{index}.attn.to_out.0.lora_B.weight"] = state_dict.pop(
f"{prefix}.{index}.attention.out.lora_B.weight"
)
# Process feed-forward weights for layers 1, 2, and 3.
for layer in range(1, 4):
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_A.weight"] = state_dict.pop(
f"{prefix}.{index}.feed_forward.w{layer}.lora_A.weight"
)
converted_state_dict[f"{prefix}.{index}.feed_forward.linear_{layer}.lora_B.weight"] = state_dict.pop(
f"{prefix}.{index}.feed_forward.w{layer}.lora_B.weight"
)
if convert_norm:
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_A.weight"] = state_dict.pop(
f"{prefix}.{index}.adaLN_modulation.1.lora_A.weight"
)
converted_state_dict[f"{prefix}.{index}.norm1.linear.lora_B.weight"] = state_dict.pop(
f"{prefix}.{index}.adaLN_modulation.1.lora_B.weight"
)
noise_refiner_pattern = r"noise_refiner\.(\d+)\."
num_noise_refiner_layers = get_num_layers(state_dict.keys(), noise_refiner_pattern)
for i in range(num_noise_refiner_layers):
process_block("noise_refiner", i, convert_norm=True)
context_refiner_pattern = r"context_refiner\.(\d+)\."
num_context_refiner_layers = get_num_layers(state_dict.keys(), context_refiner_pattern)
for i in range(num_context_refiner_layers):
process_block("context_refiner", i, convert_norm=False)
core_transformer_pattern = r"layers\.(\d+)\."
num_core_transformer_layers = get_num_layers(state_dict.keys(), core_transformer_pattern)
for i in range(num_core_transformer_layers):
process_block("layers", i, convert_norm=True)
if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
File diff suppressed because it is too large Load Diff
+12 -19
View File
@@ -53,6 +53,8 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
}
@@ -191,11 +193,6 @@ class PeftAdapterMixin:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
try:
from peft.utils.constants import FULLY_QUALIFIED_PATTERN_KEY_PREFIX
except ImportError:
FULLY_QUALIFIED_PATTERN_KEY_PREFIX = None
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -239,10 +236,7 @@ class PeftAdapterMixin:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
if prefix is not None:
keys = list(state_dict.keys())
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
if len(model_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}):
@@ -260,22 +254,16 @@ class PeftAdapterMixin:
# Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
# Bias layers in LoRA only have a single dimension
if "lora_B" in key and val.ndim > 1:
# Support to handle cases where layer patterns are treated as full layer names
# was added later in PEFT. So, we handle it accordingly.
# TODO: when we fix the minimal PEFT version for Diffusers,
# we should remove `_maybe_adjust_config()`.
if FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
rank[f"{FULLY_QUALIFIED_PATTERN_KEY_PREFIX}{key}"] = val.shape[1]
else:
rank[key] = val.shape[1]
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
rank[key] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if not FULLY_QUALIFIED_PATTERN_KEY_PREFIX:
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
# TODO: revisit this after https://github.com/huggingface/peft/pull/2382 is merged.
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
@@ -365,6 +353,11 @@ class PeftAdapterMixin:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
if prefix is not None and not state_dict:
logger.info(
f"No LoRA keys associated to {self.__class__.__name__} found with the {prefix=}. This is safe to ignore if LoRA state dict didn't originally have any {self.__class__.__name__} related params. Open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new"
)
def save_lora_adapter(
self,
save_directory,
@@ -37,8 +37,11 @@ from .single_file_utils import (
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
convert_mochi_transformer_checkpoint_to_diffusers,
convert_sana_transformer_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
convert_wan_vae_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
create_vae_diffusers_config_from_ldm,
@@ -117,6 +120,18 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
"default_subfolder": "transformer",
},
"SanaTransformer2DModel": {
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"WanTransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
}
+450 -48
View File
@@ -117,6 +117,14 @@ CHECKPOINT_KEY_NAMES = {
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
"sana": [
"blocks.0.cross_attn.q_linear.weight",
"blocks.0.cross_attn.q_linear.bias",
"blocks.0.cross_attn.kv_linear.weight",
"blocks.0.cross_attn.kv_linear.bias",
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -176,6 +184,10 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
"instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
"lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
"sana": {"pretrained_model_name_or_path": "Efficient-Large-Model/Sana_1600M_1024px_diffusers"},
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
}
# Use to configure model sample size when original config is provided
@@ -397,6 +409,7 @@ def load_single_file_checkpoint(
else:
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
user_agent = {"file_type": "single_file", "framework": "pytorch"}
pretrained_model_link_or_path = _get_model_file(
repo_id,
weights_name=weights_name,
@@ -406,6 +419,7 @@ def load_single_file_checkpoint(
local_files_only=local_files_only,
token=token,
revision=revision,
user_agent=user_agent,
)
checkpoint = load_state_dict(pretrained_model_link_or_path, disable_mmap=disable_mmap)
@@ -662,6 +676,24 @@ def infer_diffusers_model_type(checkpoint):
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
model_type = "lumina2"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sana"]):
model_type = "sana"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
if "model.diffusion_model.patch_embedding.weight" in checkpoint:
target_key = "model.diffusion_model.patch_embedding.weight"
else:
target_key = "patch_embedding.weight"
if checkpoint[target_key].shape[0] == 1536:
model_type = "wan-t2v-1.3B"
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
model_type = "wan-t2v-14B"
else:
model_type = "wan-i2v-14B"
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
else:
model_type = "v1"
@@ -1448,8 +1480,8 @@ def convert_open_clip_checkpoint(
if text_proj_key in checkpoint:
text_proj_dim = int(checkpoint[text_proj_key].shape[0])
elif hasattr(text_model.config, "projection_dim"):
text_proj_dim = text_model.config.projection_dim
elif hasattr(text_model.config, "hidden_size"):
text_proj_dim = text_model.config.hidden_size
else:
text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
@@ -2468,7 +2500,7 @@ def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
new_state_dict = {}
converted_state_dict = {}
# Comfy checkpoints add this prefix
keys = list(checkpoint.keys())
@@ -2477,22 +2509,22 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
# Convert time_embed
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
converted_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
converted_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
converted_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
converted_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
converted_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
converted_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
converted_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
converted_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
# Convert transformer blocks
num_layers = 48
@@ -2501,68 +2533,84 @@ def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
old_prefix = f"blocks.{i}."
# norm1
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
converted_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
converted_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
if i < num_layers - 1:
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
else:
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
converted_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(
old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
converted_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(
old_prefix + "mod_y.bias"
)
else:
converted_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
old_prefix + "mod_y.weight"
)
converted_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(
old_prefix + "mod_y.bias"
)
# Visual attention
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
converted_state_dict[block_prefix + "attn1.to_q.weight"] = q
converted_state_dict[block_prefix + "attn1.to_k.weight"] = k
converted_state_dict[block_prefix + "attn1.to_v.weight"] = v
converted_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(
old_prefix + "attn.q_norm_x.weight"
)
converted_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(
old_prefix + "attn.k_norm_x.weight"
)
converted_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(
old_prefix + "attn.proj_x.weight"
)
converted_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
# Context attention
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
converted_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
converted_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
converted_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
converted_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
old_prefix + "attn.q_norm_y.weight"
)
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
converted_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
old_prefix + "attn.k_norm_y.weight"
)
if i < num_layers - 1:
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
converted_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
old_prefix + "attn.proj_y.weight"
)
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
converted_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(
old_prefix + "attn.proj_y.bias"
)
# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
converted_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
)
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
converted_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
if i < num_layers - 1:
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
converted_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
)
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
converted_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(
old_prefix + "mlp_y.w2.weight"
)
# Output layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
converted_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
return new_state_dict
return converted_state_dict
def convert_hunyuan_video_transformer_to_diffusers(checkpoint, **kwargs):
@@ -2857,3 +2905,357 @@ def convert_lumina2_to_diffusers(checkpoint, **kwargs):
converted_state_dict[diffusers_key] = checkpoint.pop(key)
return converted_state_dict
def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "blocks" in k))[-1] + 1 # noqa: C401
# Positional and patch embeddings.
checkpoint.pop("pos_embed")
converted_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
converted_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
# Timestep embeddings.
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = checkpoint.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = checkpoint.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
converted_state_dict["time_embed.linear.weight"] = checkpoint.pop("t_block.1.weight")
converted_state_dict["time_embed.linear.bias"] = checkpoint.pop("t_block.1.bias")
# Caption Projection.
checkpoint.pop("y_embedder.y_embedding")
converted_state_dict["caption_projection.linear_1.weight"] = checkpoint.pop("y_embedder.y_proj.fc1.weight")
converted_state_dict["caption_projection.linear_1.bias"] = checkpoint.pop("y_embedder.y_proj.fc1.bias")
converted_state_dict["caption_projection.linear_2.weight"] = checkpoint.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = checkpoint.pop("y_embedder.y_proj.fc2.bias")
converted_state_dict["caption_norm.weight"] = checkpoint.pop("attention_y_norm.weight")
for i in range(num_layers):
converted_state_dict[f"transformer_blocks.{i}.scale_shift_table"] = checkpoint.pop(
f"blocks.{i}.scale_shift_table"
)
# Self-Attention
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"blocks.{i}.attn.qkv.weight"), 3, dim=0)
converted_state_dict[f"transformer_blocks.{i}.attn1.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"transformer_blocks.{i}.attn1.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"transformer_blocks.{i}.attn1.to_v.weight"] = torch.cat([sample_v])
# Output Projections
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.weight"] = checkpoint.pop(
f"blocks.{i}.attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn1.to_out.0.bias"] = checkpoint.pop(
f"blocks.{i}.attn.proj.bias"
)
# Cross-Attention
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.weight"] = checkpoint.pop(
f"blocks.{i}.cross_attn.q_linear.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_q.bias"] = checkpoint.pop(
f"blocks.{i}.cross_attn.q_linear.bias"
)
linear_sample_k, linear_sample_v = torch.chunk(
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.weight"), 2, dim=0
)
linear_sample_k_bias, linear_sample_v_bias = torch.chunk(
checkpoint.pop(f"blocks.{i}.cross_attn.kv_linear.bias"), 2, dim=0
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.weight"] = linear_sample_k
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.weight"] = linear_sample_v
converted_state_dict[f"transformer_blocks.{i}.attn2.to_k.bias"] = linear_sample_k_bias
converted_state_dict[f"transformer_blocks.{i}.attn2.to_v.bias"] = linear_sample_v_bias
# Output Projections
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.weight"] = checkpoint.pop(
f"blocks.{i}.cross_attn.proj.weight"
)
converted_state_dict[f"transformer_blocks.{i}.attn2.to_out.0.bias"] = checkpoint.pop(
f"blocks.{i}.cross_attn.proj.bias"
)
# MLP
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.weight"] = checkpoint.pop(
f"blocks.{i}.mlp.inverted_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff.conv_inverted.bias"] = checkpoint.pop(
f"blocks.{i}.mlp.inverted_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.weight"] = checkpoint.pop(
f"blocks.{i}.mlp.depth_conv.conv.weight"
)
converted_state_dict[f"transformer_blocks.{i}.ff.conv_depth.bias"] = checkpoint.pop(
f"blocks.{i}.mlp.depth_conv.conv.bias"
)
converted_state_dict[f"transformer_blocks.{i}.ff.conv_point.weight"] = checkpoint.pop(
f"blocks.{i}.mlp.point_conv.conv.weight"
)
# Final layer
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["scale_shift_table"] = checkpoint.pop("final_layer.scale_shift_table")
return converted_state_dict
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
"time_projection.1": "condition_embedder.time_proj",
"cross_attn": "attn2",
"self_attn": "attn1",
".o.": ".to_out.0.",
".q.": ".to_q.",
".k.": ".to_k.",
".v.": ".to_v.",
".k_img.": ".add_k_proj.",
".v_img.": ".add_v_proj.",
".norm_k_img.": ".norm_added_k.",
"head.modulation": "scale_shift_table",
"head.head": "proj_out",
"modulation": "scale_shift_table",
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
# For the I2V model
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
}
for key in list(checkpoint.keys()):
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = checkpoint.pop(key)
return converted_state_dict
def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
# Create mappings for specific components
middle_key_mapping = {
# Encoder middle block
"encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
"encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
"encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
"encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
"encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
"encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
"encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
"encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
"encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
"encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
"encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
"encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
# Decoder middle block
"decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
"decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
"decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
"decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
"decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
"decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
"decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
"decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
"decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
"decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
"decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
"decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
}
# Create a mapping for attention blocks
attention_mapping = {
# Encoder middle attention
"encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
"encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
"encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
"encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
"encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
# Decoder middle attention
"decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
"decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
"decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
"decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
"decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
}
# Create a mapping for the head components
head_mapping = {
# Encoder head
"encoder.head.0.gamma": "encoder.norm_out.gamma",
"encoder.head.2.bias": "encoder.conv_out.bias",
"encoder.head.2.weight": "encoder.conv_out.weight",
# Decoder head
"decoder.head.0.gamma": "decoder.norm_out.gamma",
"decoder.head.2.bias": "decoder.conv_out.bias",
"decoder.head.2.weight": "decoder.conv_out.weight",
}
# Create a mapping for the quant components
quant_mapping = {
"conv1.weight": "quant_conv.weight",
"conv1.bias": "quant_conv.bias",
"conv2.weight": "post_quant_conv.weight",
"conv2.bias": "post_quant_conv.bias",
}
# Process each key in the state dict
for key, value in checkpoint.items():
# Handle middle block keys using the mapping
if key in middle_key_mapping:
new_key = middle_key_mapping[key]
converted_state_dict[new_key] = value
# Handle attention blocks using the mapping
elif key in attention_mapping:
new_key = attention_mapping[key]
converted_state_dict[new_key] = value
# Handle head keys using the mapping
elif key in head_mapping:
new_key = head_mapping[key]
converted_state_dict[new_key] = value
# Handle quant keys using the mapping
elif key in quant_mapping:
new_key = quant_mapping[key]
converted_state_dict[new_key] = value
# Handle encoder conv1
elif key == "encoder.conv1.weight":
converted_state_dict["encoder.conv_in.weight"] = value
elif key == "encoder.conv1.bias":
converted_state_dict["encoder.conv_in.bias"] = value
# Handle decoder conv1
elif key == "decoder.conv1.weight":
converted_state_dict["decoder.conv_in.weight"] = value
elif key == "decoder.conv1.bias":
converted_state_dict["decoder.conv_in.bias"] = value
# Handle encoder downsamples
elif key.startswith("encoder.downsamples."):
# Convert to down_blocks
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
# Convert residual block naming but keep the original structure
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
elif ".residual.2.bias" in new_key:
new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
elif ".residual.2.weight" in new_key:
new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
elif ".residual.3.gamma" in new_key:
new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
elif ".residual.6.bias" in new_key:
new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
elif ".residual.6.weight" in new_key:
new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
elif ".shortcut.weight" in new_key:
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
converted_state_dict[new_key] = value
# Handle decoder upsamples
elif key.startswith("decoder.upsamples."):
# Convert to up_blocks
parts = key.split(".")
block_idx = int(parts[2])
# Group residual blocks
if "residual" in key:
if block_idx in [0, 1, 2]:
new_block_idx = 0
resnet_idx = block_idx
elif block_idx in [4, 5, 6]:
new_block_idx = 1
resnet_idx = block_idx - 4
elif block_idx in [8, 9, 10]:
new_block_idx = 2
resnet_idx = block_idx - 8
elif block_idx in [12, 13, 14]:
new_block_idx = 3
resnet_idx = block_idx - 12
else:
# Keep as is for other blocks
converted_state_dict[key] = value
continue
# Convert residual block naming
if ".residual.0.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
elif ".residual.2.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
elif ".residual.2.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
elif ".residual.3.gamma" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
elif ".residual.6.bias" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
elif ".residual.6.weight" in key:
new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
else:
new_key = key
converted_state_dict[new_key] = value
# Handle shortcut connections
elif ".shortcut." in key:
if block_idx == 4:
new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
converted_state_dict[new_key] = value
# Handle upsamplers
elif ".resample." in key or ".time_conv." in key:
if block_idx == 3:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
elif block_idx == 7:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
elif block_idx == 11:
new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
converted_state_dict[new_key] = value
else:
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
converted_state_dict[new_key] = value
else:
# Keep other keys unchanged
converted_state_dict[key] = value
return converted_state_dict
Regular → Executable
+8
View File
@@ -33,8 +33,10 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
@@ -71,6 +73,7 @@ if is_torch_available():
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -79,6 +82,7 @@ if is_torch_available():
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -107,8 +111,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLCogVideoX,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
@@ -141,6 +147,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsisIDTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxTransformer2DModel,
HunyuanDiT2DModel,
HunyuanVideoTransformer3DModel,
@@ -158,6 +165,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
WanTransformer3DModel,
)
from .unets import (
I2VGenXLUNet,
+9 -1
View File
@@ -274,12 +274,20 @@ class Attention(nn.Module):
self.to_add_out = None
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
if qk_norm == "layer_norm":
self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "fp32_layer_norm":
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
elif qk_norm == "rms_norm":
self.norm_added_q = RMSNorm(dim_head, eps=eps)
self.norm_added_k = RMSNorm(dim_head, eps=eps)
elif qk_norm == "rms_norm_across_heads":
# Wan applies qk norm across all heads
# Wan also doesn't apply a q norm
self.norm_added_q = None
self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
@@ -5,8 +5,10 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_kl_wan import AutoencoderKLWan
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,855 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CACHE_T = 2
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
caching for efficient inference.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
) -> None:
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
)
# Set up causal padding
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
self.padding = (0, 0, 0)
def forward(self, x, cache_x=None):
padding = list(self._padding)
if cache_x is not None and self._padding[4] > 0:
cache_x = cache_x.to(x.device)
x = torch.cat([cache_x, x], dim=2)
padding[4] -= cache_x.shape[2]
x = F.pad(x, padding)
return super().forward(x)
class WanRMS_norm(nn.Module):
r"""
A custom RMS normalization layer.
Args:
dim (int): The number of dimensions to normalize over.
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
Default is True.
images (bool, optional): Whether the input represents image data. Default is True.
bias (bool, optional): Whether to include a learnable bias term. Default is False.
"""
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class WanUpsample(nn.Upsample):
r"""
Perform upsampling while ensuring the output tensor has the same data type as the input.
Args:
x (torch.Tensor): Input tensor to be upsampled.
Returns:
torch.Tensor: Upsampled tensor with the same data type as the input.
"""
def forward(self, x):
return super().forward(x.float()).type_as(x)
class WanResample(nn.Module):
r"""
A custom resampling module for 2D and 3D data.
Args:
dim (int): The number of input/output channels.
mode (str): The resampling mode. Must be one of:
- 'none': No resampling (identity operation).
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
- 'downsample2d': 2D downsampling with zero-padding and convolution.
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
def __init__(self, dim: int, mode: str) -> None:
super().__init__()
self.dim = dim
self.mode = mode
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
elif mode == "downsample2d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
elif mode == "downsample3d":
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
else:
self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
b, c, t, h, w = x.size()
if self.mode == "upsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = "Rep"
feat_idx[0] += 1
else:
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
# cache last frame of last two chunk
cache_x = torch.cat(
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
)
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
if feat_cache[idx] == "Rep":
x = self.time_conv(x)
else:
x = self.time_conv(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
x = x.reshape(b, 2, c, t, h, w)
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
x = x.reshape(b, c, t * 2, h, w)
t = x.shape[2]
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.resample(x)
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
if self.mode == "downsample3d":
if feat_cache is not None:
idx = feat_idx[0]
if feat_cache[idx] is None:
feat_cache[idx] = x.clone()
feat_idx[0] += 1
else:
cache_x = x[:, :, -1:, :, :].clone()
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x
feat_idx[0] += 1
return x
class WanResidualBlock(nn.Module):
r"""
A custom residual block module.
Args:
in_dim (int): Number of input channels.
out_dim (int): Number of output channels.
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
"""
def __init__(
self,
in_dim: int,
out_dim: int,
dropout: float = 0.0,
non_linearity: str = "silu",
) -> None:
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.nonlinearity = get_activation(non_linearity)
# layers
self.norm1 = WanRMS_norm(in_dim, images=False)
self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
self.norm2 = WanRMS_norm(out_dim, images=False)
self.dropout = nn.Dropout(dropout)
self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]):
# Apply shortcut connection
h = self.conv_shortcut(x)
# First normalization and activation
x = self.norm1(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv1(x)
# Second normalization and activation
x = self.norm2(x)
x = self.nonlinearity(x)
# Dropout
x = self.dropout(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv2(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv2(x)
# Add residual connection
return x + h
class WanAttentionBlock(nn.Module):
r"""
Causal self-attention with a single head.
Args:
dim (int): The number of channels in the input tensor.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
# layers
self.norm = WanRMS_norm(dim)
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
self.proj = nn.Conv2d(dim, dim, 1)
def forward(self, x):
identity = x
batch_size, channels, time, height, width = x.size()
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
x = self.norm(x)
# compute query, key, value
qkv = self.to_qkv(x)
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
qkv = qkv.permute(0, 1, 3, 2).contiguous()
q, k, v = qkv.chunk(3, dim=-1)
# apply attention
x = F.scaled_dot_product_attention(q, k, v)
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
# output projection
x = self.proj(x)
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
x = x.view(batch_size, time, channels, height, width)
x = x.permute(0, 2, 1, 3, 4)
return x + identity
class WanMidBlock(nn.Module):
"""
Middle block for WanVAE encoder and decoder.
Args:
dim (int): Number of input/output channels.
dropout (float): Dropout rate.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
super().__init__()
self.dim = dim
# Create the components
resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
attentions = []
for _ in range(num_layers):
attentions.append(WanAttentionBlock(dim))
resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
x = self.resnets[0](x, feat_cache, feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
x = resnet(x, feat_cache, feat_idx)
return x
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_downsample (list of bool): Whether to downsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_downsample = temperal_downsample
self.nonlinearity = get_activation(non_linearity)
# dimensions
dims = [dim * u for u in [1] + dim_mult]
scale = 1.0
# init block
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
for _ in range(num_res_blocks):
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
if scale in attn_scales:
self.down_blocks.append(WanAttentionBlock(out_dim))
in_dim = out_dim
# downsample block
if i != len(dim_mult) - 1:
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
self.down_blocks.append(WanResample(out_dim, mode=mode))
scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class WanUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
Args:
in_dim (int): Input dimension
out_dim (int): Output dimension
num_res_blocks (int): Number of residual blocks
dropout (float): Dropout rate
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
non_linearity (str): Type of non-linearity to use
"""
def __init__(
self,
in_dim: int,
out_dim: int,
num_res_blocks: int,
dropout: float = 0.0,
upsample_mode: Optional[str] = None,
non_linearity: str = "silu",
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
# Create layers list
resnets = []
# Add residual blocks and attention if needed
current_dim = in_dim
for _ in range(num_res_blocks + 1):
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
current_dim = out_dim
self.resnets = nn.ModuleList(resnets)
# Add upsampling layer if needed
self.upsamplers = None
if upsample_mode is not None:
self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
"""
Forward pass through the upsampling block.
Args:
x (torch.Tensor): Input tensor
feat_cache (list, optional): Feature cache for causal convolutions
feat_idx (list, optional): Feature index for cache management
Returns:
torch.Tensor: Output tensor
"""
for resnet in self.resnets:
if feat_cache is not None:
x = resnet(x, feat_cache, feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
x = self.upsamplers[0](x, feat_cache, feat_idx)
else:
x = self.upsamplers[0](x)
return x
class WanDecoder3d(nn.Module):
r"""
A 3D decoder module.
Args:
dim (int): The base number of channels in the first layer.
z_dim (int): The dimensionality of the latent space.
dim_mult (list of int): Multipliers for the number of channels in each block.
num_res_blocks (int): Number of residual blocks in each block.
attn_scales (list of float): Scales at which to apply attention mechanisms.
temperal_upsample (list of bool): Whether to upsample temporally in each block.
dropout (float): Dropout rate for the dropout layers.
non_linearity (str): Type of non-linearity to use.
"""
def __init__(
self,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
num_res_blocks=2,
attn_scales=[],
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
):
super().__init__()
self.dim = dim
self.z_dim = z_dim
self.dim_mult = dim_mult
self.num_res_blocks = num_res_blocks
self.attn_scales = attn_scales
self.temperal_upsample = temperal_upsample
self.nonlinearity = get_activation(non_linearity)
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
# middle blocks
self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
# upsample blocks
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
if i > 0:
in_dim = in_dim // 2
# Determine if we need upsampling
upsample_mode = None
if i != len(dim_mult) - 1:
upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
# Create and add the upsampling block
up_block = WanUpBlock(
in_dim=in_dim,
out_dim=out_dim,
num_res_blocks=num_res_blocks,
dropout=dropout,
upsample_mode=upsample_mode,
non_linearity=non_linearity,
)
self.up_blocks.append(up_block)
# Update scale for next iteration
if upsample_mode is not None:
scale *= 2.0
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
self.gradient_checkpointing = False
def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_in(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_in(x)
## middle
x = self.mid_block(x, feat_cache, feat_idx)
## upsamples
for up_block in self.up_blocks:
x = up_block(x, feat_cache, feat_idx)
## head
x = self.norm_out(x)
x = self.nonlinearity(x)
if feat_cache is not None:
idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.conv_out(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = self.conv_out(x)
return x
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
base_dim: int = 96,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
dropout: float = 0.0,
latents_mean: List[float] = [
-0.7571,
-0.7089,
-0.9113,
0.1075,
-0.1745,
0.9653,
-0.1517,
1.5508,
0.4134,
-0.0715,
0.5517,
-0.3632,
-0.1922,
-0.9497,
0.2503,
-0.2921,
],
latents_std: List[float] = [
2.8184,
1.4541,
2.3275,
2.6558,
1.2196,
1.7708,
2.6052,
2.0743,
3.2687,
2.1526,
2.8652,
1.5579,
1.6382,
1.1253,
2.8251,
1.9160,
],
) -> None:
super().__init__()
self.z_dim = z_dim
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
self.encoder = WanEncoder3d(
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
)
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
self.decoder = WanDecoder3d(
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
)
def clear_cache(self):
def _count_conv3d(model):
count = 0
for m in model.modules():
if isinstance(m, WanCausalConv3d):
count += 1
return count
self._conv_num = _count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num
# cache encode
self._enc_conv_num = _count_conv3d(self.encoder)
self._enc_conv_idx = [0]
self._enc_feat_map = [None] * self._enc_conv_num
def _encode(self, x: torch.Tensor) -> torch.Tensor:
self.clear_cache()
## cache
t = x.shape[2]
iter_ = 1 + (t - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
if i == 0:
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
else:
out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
feat_cache=self._enc_feat_map,
feat_idx=self._enc_conv_idx,
)
out = torch.cat([out, out_], 2)
enc = self.quant_conv(out)
mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
enc = torch.cat([mu, logvar], dim=1)
self.clear_cache()
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
self.clear_cache()
iter_ = z.shape[2]
x = self.post_quant_conv(z)
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()
if not return_dict:
return (out,)
return DecoderOutput(sample=out)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
"""
Args:
sample (`torch.Tensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec
@@ -605,12 +605,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
controlnet_cond: List[torch.Tensor],
control_type: torch.Tensor,
control_type_idx: List[int],
conditioning_scale: float = 1.0,
conditioning_scale: Union[float, List[float]] = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
from_multi: bool = False,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
@@ -647,6 +648,8 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
from_multi (`bool`, defaults to `False`):
Use standard scaling when called from `MultiControlNetUnionModel`.
guess_mode (`bool`, defaults to `False`):
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
@@ -658,6 +661,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
if isinstance(conditioning_scale, float):
conditioning_scale = [conditioning_scale] * len(controlnet_cond)
# check channel order
channel_order = self.config.controlnet_conditioning_channel_order
@@ -742,12 +748,16 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
inputs = []
condition_list = []
for cond, control_idx in zip(controlnet_cond, control_type_idx):
for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
if from_multi:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
inputs.append(feat_seq.unsqueeze(1) * scale)
condition_list.append(condition * scale)
condition = sample
feat_seq = torch.mean(condition, dim=(2, 3))
@@ -759,10 +769,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
x = layer(x)
controlnet_cond_fuser = sample * 0.0
for idx, condition in enumerate(condition_list[:-1]):
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
controlnet_cond_fuser += condition + alpha
if from_multi:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale
sample = sample + controlnet_cond_fuser
@@ -806,12 +819,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
if from_multi:
scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
else:
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale
elif from_multi:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
if self.config.global_pool_conditions:
down_block_res_samples = [
@@ -47,9 +47,12 @@ class MultiControlNetUnionModel(ModelMixin):
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[ControlNetOutput, Tuple]:
down_block_res_samples, mid_block_res_sample = None, None
for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
):
if scale == 0.0:
continue
down_samples, mid_sample = controlnet(
sample=sample,
timestep=timestep,
@@ -63,12 +66,13 @@ class MultiControlNetUnionModel(ModelMixin):
attention_mask=attention_mask,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
from_multi=True,
guess_mode=guess_mode,
return_dict=return_dict,
)
# merge samples
if i == 0:
if down_block_res_samples is None and mid_block_res_sample is None:
down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
else:
down_block_res_samples = [
+6 -1
View File
@@ -245,6 +245,9 @@ def load_model_dict_into_meta(
):
param = param.to(torch.float32)
set_module_kwargs["dtype"] = torch.float32
# For quantizers have save weights using torch.float8_e4m3fn
elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
pass
else:
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
@@ -292,7 +295,9 @@ def load_model_dict_into_meta(
elif is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
):
hf_quantizer.create_quantized_param(model, param, param_name, param_device, state_dict, unexpected_keys)
hf_quantizer.create_quantized_param(
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
)
else:
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
+5 -1
View File
@@ -166,8 +166,12 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
# 2. If no dtype modifying hooks are attached, return the dtype of the first floating point parameter/buffer
last_dtype = None
for param in parameter.parameters():
for name, param in parameter.named_parameters():
last_dtype = param.dtype
if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
continue
if param.is_floating_point():
return param.dtype
+2
View File
@@ -19,6 +19,7 @@ if is_torch_available():
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
@@ -27,3 +28,4 @@ if is_torch_available():
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
@@ -18,7 +18,7 @@ import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import (
Attention,
@@ -195,7 +195,7 @@ class SanaTransformerBlock(nn.Module):
return hidden_states
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
@@ -12,20 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import Attention
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import logging
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -244,30 +245,34 @@ class CogView4RotaryPosEmbed(nn.Module):
def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.rope_axes_dim = rope_axes_dim
dim_h, dim_w = dim // 2, dim // 2
h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
h_seq = torch.arange(self.rope_axes_dim[0])
w_seq = torch.arange(self.rope_axes_dim[1])
self.freqs_h = torch.outer(h_seq, h_inv_freq)
self.freqs_w = torch.outer(w_seq, w_inv_freq)
self.theta = theta
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
h_idx = torch.arange(height)
w_idx = torch.arange(width)
dim_h, dim_w = self.dim // 2, self.dim // 2
h_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
)
w_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
)
h_seq = torch.arange(self.rope_axes_dim[0])
w_seq = torch.arange(self.rope_axes_dim[1])
freqs_h = torch.outer(h_seq, h_inv_freq)
freqs_w = torch.outer(w_seq, w_inv_freq)
h_idx = torch.arange(height, device=freqs_h.device)
w_idx = torch.arange(width, device=freqs_w.device)
inner_h_idx = h_idx * self.rope_axes_dim[0] // height
inner_w_idx = w_idx * self.rope_axes_dim[1] // width
self.freqs_h = self.freqs_h.to(hidden_states.device)
self.freqs_w = self.freqs_w.to(hidden_states.device)
freqs_h = self.freqs_h[inner_h_idx]
freqs_w = self.freqs_w[inner_w_idx]
freqs_h = freqs_h[inner_h_idx]
freqs_w = freqs_w[inner_w_idx]
# Create position matrices for height and width
# [height, 1, dim//4] and [1, width, dim//4]
@@ -284,7 +289,7 @@ class CogView4RotaryPosEmbed(nn.Module):
return (freqs.cos(), freqs.sin())
class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
r"""
Args:
patch_size (`int`, defaults to `2`):
@@ -379,8 +384,24 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
original_size: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
@@ -415,6 +436,10 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -0,0 +1,527 @@
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..embeddings import TimestepEmbedding, Timesteps, get_3d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, FP32LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class EasyAnimateLayerNormZero(nn.Module):
def __init__(
self,
conditioning_dim: int,
embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
norm_type: str = "fp32_layer_norm",
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale.unsqueeze(1)) + enc_shift.unsqueeze(
1
)
return hidden_states, encoder_hidden_states, gate, enc_gate
class EasyAnimateRotaryPosEmbed(nn.Module):
def __init__(self, patch_size: int, rope_dim: List[int]) -> None:
super().__init__()
self.patch_size = patch_size
self.rope_dim = rope_dim
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
bs, c, num_frames, grid_height, grid_width = hidden_states.size()
grid_height = grid_height // self.patch_size
grid_width = grid_width // self.patch_size
base_size_width = 90 // self.patch_size
base_size_height = 60 // self.patch_size
grid_crops_coords = self.get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
image_rotary_emb = get_3d_rotary_pos_embed(
self.rope_dim,
grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=hidden_states.size(2),
use_real=True,
)
return image_rotary_emb
class EasyAnimateAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
used in the EasyAnimateTransformer3DModel model.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"EasyAnimateAttnProcessor2_0 requires PyTorch 2.0 or above. To use it, please install PyTorch 2.0."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if attn.add_q_proj is None and encoder_hidden_states is not None:
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Encoder condition QKV projection and normalization
if attn.add_q_proj is not None and encoder_hidden_states is not None:
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
query = torch.cat([encoder_query, query], dim=2)
key = torch.cat([encoder_key, key], dim=2)
value = torch.cat([encoder_value, value], dim=2)
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
query[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
)
if not attn.is_cross_attention:
key[:, :, encoder_hidden_states.shape[1] :] = apply_rotary_emb(
key[:, :, encoder_hidden_states.shape[1] :], image_rotary_emb
)
# 5. Attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# 6. Output projection
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
else:
if getattr(attn, "to_out", None) is not None:
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class EasyAnimateTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-6,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
qk_norm: bool = True,
after_norm: bool = False,
norm_type: str = "fp32_layer_norm",
is_mmdit_block: bool = True,
):
super().__init__()
# Attention Part
self.norm1 = EasyAnimateLayerNormZero(
time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
added_proj_bias=True,
added_kv_proj_dim=dim if is_mmdit_block else None,
context_pre_only=False if is_mmdit_block else None,
processor=EasyAnimateAttnProcessor2_0(),
)
# FFN Part
self.norm2 = EasyAnimateLayerNormZero(
time_embed_dim, dim, norm_elementwise_affine, norm_eps, norm_type=norm_type, bias=True
)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
self.txt_ff = None
if is_mmdit_block:
self.txt_ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
self.norm3 = None
if after_norm:
self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Attention
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
# 2. Feed-forward
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
if self.norm3 is not None:
norm_hidden_states = self.norm3(self.ff(norm_hidden_states))
if self.txt_ff is not None:
norm_encoder_hidden_states = self.norm3(self.txt_ff(norm_encoder_hidden_states))
else:
norm_encoder_hidden_states = self.norm3(self.ff(norm_encoder_hidden_states))
else:
norm_hidden_states = self.ff(norm_hidden_states)
if self.txt_ff is not None:
norm_encoder_hidden_states = self.txt_ff(norm_encoder_hidden_states)
else:
norm_encoder_hidden_states = self.ff(norm_encoder_hidden_states)
hidden_states = hidden_states + gate_ff.unsqueeze(1) * norm_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_ff.unsqueeze(1) * norm_encoder_hidden_states
return hidden_states, encoder_hidden_states
class EasyAnimateTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [EasyAnimate](https://github.com/aigc-apps/EasyAnimate).
Parameters:
num_attention_heads (`int`, defaults to `48`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `64`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `16`):
The number of channels in the output.
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
sample_width (`int`, defaults to `90`):
The width of the input latents.
sample_height (`int`, defaults to `60`):
The height of the input latents.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
timestep_activation_fn (`str`, defaults to `"silu"`):
Activation function to use when generating the timestep embeddings.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
mmdit_layers (`int`, defaults to `1000`):
The number of layers of Multi Modal Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
norm_eps (`float`, defaults to `1e-5`):
The epsilon value to use in normalization layers.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use elementwise affine in normalization layers.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
time_position_encoding_type (`str`, defaults to `3d_rope`):
Type of time position encoding.
after_norm (`bool`, defaults to `False`):
Flag to apply normalization after.
resize_inpaint_mask_directly (`bool`, defaults to `True`):
Flag to resize inpaint mask directly.
enable_text_attention_mask (`bool`, defaults to `True`):
Flag to enable text attention mask.
add_noise_in_inpaint_model (`bool`, defaults to `False`):
Flag to add noise in inpaint model.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["EasyAnimateTransformerBlock"]
_skip_layerwise_casting_patterns = ["^proj$", "norm", "^proj_out$"]
@register_to_config
def __init__(
self,
num_attention_heads: int = 48,
attention_head_dim: int = 64,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
patch_size: Optional[int] = None,
sample_width: int = 90,
sample_height: int = 60,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
freq_shift: int = 0,
num_layers: int = 48,
mmdit_layers: int = 48,
dropout: float = 0.0,
time_embed_dim: int = 512,
add_norm_text_encoder: bool = False,
text_embed_dim: int = 3584,
text_embed_dim_t5: int = None,
norm_eps: float = 1e-5,
norm_elementwise_affine: bool = True,
flip_sin_to_cos: bool = True,
time_position_encoding_type: str = "3d_rope",
after_norm=False,
resize_inpaint_mask_directly: bool = True,
enable_text_attention_mask: bool = True,
add_noise_in_inpaint_model: bool = True,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
# 1. Timestep embedding
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
self.rope_embedding = EasyAnimateRotaryPosEmbed(patch_size, attention_head_dim)
# 2. Patch embedding
self.proj = nn.Conv2d(
in_channels, inner_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=True
)
# 3. Text refined embedding
self.text_proj = None
self.text_proj_t5 = None
if not add_norm_text_encoder:
self.text_proj = nn.Linear(text_embed_dim, inner_dim)
if text_embed_dim_t5 is not None:
self.text_proj_t5 = nn.Linear(text_embed_dim_t5, inner_dim)
else:
self.text_proj = nn.Sequential(
RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim, inner_dim)
)
if text_embed_dim_t5 is not None:
self.text_proj_t5 = nn.Sequential(
RMSNorm(text_embed_dim, 1e-6, elementwise_affine=True), nn.Linear(text_embed_dim_t5, inner_dim)
)
# 4. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
EasyAnimateTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
after_norm=after_norm,
is_mmdit_block=True if _ < mmdit_layers else False,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output norm & projection
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
timestep_cond: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_hidden_states_t5: Optional[torch.Tensor] = None,
inpaint_latents: Optional[torch.Tensor] = None,
control_latents: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
batch_size, channels, video_length, height, width = hidden_states.size()
p = self.config.patch_size
post_patch_height = height // p
post_patch_width = width // p
# 1. Time embedding
temb = self.time_proj(timestep).to(dtype=hidden_states.dtype)
temb = self.time_embedding(temb, timestep_cond)
image_rotary_emb = self.rope_embedding(hidden_states)
# 2. Patch embedding
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
if control_latents is not None:
hidden_states = torch.concat([hidden_states, control_latents], 1)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) # [B, C, F, H, W] -> [BF, C, H, W]
hidden_states = self.proj(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(
0, 2, 1, 3, 4
) # [BF, C, H, W] -> [B, F, C, H, W]
hidden_states = hidden_states.flatten(2, 4).transpose(1, 2) # [B, F, C, H, W] -> [B, FHW, C]
# 3. Text embedding
encoder_hidden_states = self.text_proj(encoder_hidden_states)
if encoder_hidden_states_t5 is not None:
encoder_hidden_states_t5 = self.text_proj_t5(encoder_hidden_states_t5)
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1).contiguous()
# 4. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, image_rotary_emb
)
hidden_states = self.norm_final(hidden_states)
# 5. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb=temb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, video_length, post_patch_height, post_patch_width, channels, p, p)
output = output.permute(0, 4, 1, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -581,7 +581,11 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
if guidance_embeds:
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
else:
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
@@ -708,7 +712,11 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
temb = self.time_text_embed(timestep, guidance, pooled_projections)
if self.config.guidance_embeds:
temb = self.time_text_embed(timestep, guidance, pooled_projections)
else:
temb = self.time_text_embed(timestep, pooled_projections)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
@@ -0,0 +1,460 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
encoder_hidden_states_img = encoder_hidden_states[:, :257]
encoder_hidden_states = encoder_hidden_states[:, 257:]
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)
key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.norm1 = FP32LayerNorm(in_features)
self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
self.norm2 = FP32LayerNorm(out_features)
def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
hidden_states = self.norm1(encoder_hidden_states_image)
hidden_states = self.ff(hidden_states)
hidden_states = self.norm2(hidden_states)
return hidden_states
class WanTimeTextImageEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
image_embed_dim: Optional[int] = None,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
self.image_embedder = None
if image_embed_dim is not None:
self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class WanRotaryPosEmbed(nn.Module):
def __init__(
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
):
super().__init__()
self.attention_head_dim = attention_head_dim
self.patch_size = patch_size
self.max_seq_len = max_seq_len
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs = []
for dim in [t_dim, h_dim, w_dim]:
freq = get_1d_rotary_pos_embed(
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
)
freqs.append(freq)
self.freqs = torch.cat(freqs, dim=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
self.freqs = self.freqs.to(hidden_states.device)
freqs = self.freqs.split_with_sizes(
[
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
self.attention_head_dim // 6,
self.attention_head_dim // 6,
],
dim=1,
)
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
return freqs
class WanTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: Optional[int] = None,
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
processor=WanAttnProcessor2_0(),
)
# 2. Cross-attention
self.attn2 = Attention(
query_dim=dim,
heads=num_heads,
kv_heads=num_heads,
dim_head=dim // num_heads,
qk_norm=qk_norm,
eps=eps,
bias=True,
cross_attention_dim=None,
out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
added_proj_bias=True,
processor=WanAttnProcessor2_0(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
hidden_states
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in the Wan model.
Args:
patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
num_attention_heads (`int`, defaults to `40`):
Fixed length for text embeddings.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_dim (`int`, defaults to `512`):
Input dimension for text embeddings.
freq_dim (`int`, defaults to `256`):
Dimension for sinusoidal time embeddings.
ffn_dim (`int`, defaults to `13824`):
Intermediate dimension in feed-forward network.
num_layers (`int`, defaults to `40`):
The number of layers of transformer blocks to use.
window_size (`Tuple[int]`, defaults to `(-1, -1)`):
Window size for local attention (-1 indicates global attention).
cross_attn_norm (`bool`, defaults to `True`):
Enable cross-attention normalization.
qk_norm (`bool`, defaults to `True`):
Enable query/key normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
add_img_emb (`bool`, defaults to `False`):
Whether to use img_emb.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
@register_to_config
def __init__(
self,
patch_size: Tuple[int] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
out_channels: int = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 13824,
num_layers: int = 40,
cross_attn_norm: bool = True,
qk_norm: Optional[str] = "rms_norm_across_heads",
eps: float = 1e-6,
image_dim: Optional[int] = None,
added_kv_proj_dim: Optional[int] = None,
rope_max_seq_len: int = 1024,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Patch & position embedding
self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
# 2. Condition embeddings
# image_embedding_dim=1280 for I2V model
self.condition_embedder = WanTimeTextImageEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
image_embed_dim=image_dim,
)
# 3. Transformer blocks
self.blocks = nn.ModuleList(
[
WanTransformerBlock(
inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
)
for _ in range(num_layers)
]
)
# 4. Output norm & projection
self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
rotary_emb = self.rope(hidden_states)
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
timestep, encoder_hidden_states, encoder_hidden_states_image
)
timestep_proj = timestep_proj.unflatten(1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
)
else:
for block in self.blocks:
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
-1
View File
@@ -240,7 +240,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
+22 -2
View File
@@ -216,8 +216,17 @@ else:
"IFPipeline",
"IFSuperResolutionPipeline",
]
_import_structure["easyanimate"] = [
"EasyAnimatePipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimateControlPipeline",
]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = ["HunyuanVideoPipeline", "HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["hunyuan_video"] = [
"HunyuanVideoPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -347,6 +356,7 @@ else:
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -545,6 +555,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
from .easyanimate import (
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
)
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
@@ -559,7 +574,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPriorReduxPipeline,
ReduxImageEncoder,
)
from .hunyuan_video import HunyuanSkyreelsImageToVideoPipeline, HunyuanVideoPipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
@@ -690,6 +709,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from .wan import WanImageToVideoPipeline, WanPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
@@ -19,7 +19,7 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
@@ -83,6 +83,7 @@ class AnimateDiffPipeline(
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
FromSingleFileMixin,
):
r"""
Pipeline for text-to-video generation.
@@ -20,7 +20,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import (
AutoencoderKL,
ControlNetModel,
@@ -125,6 +125,7 @@ class AnimateDiffControlNetPipeline(
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
FromSingleFileMixin,
):
r"""
Pipeline for text-to-video generation with ControlNet guidance.
@@ -22,7 +22,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -136,6 +136,7 @@ class AnimateDiffSparseControlNetPipeline(
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
FromSingleFileMixin,
):
r"""
Pipeline for controlled text-to-video generation using the method described in [SparseCtrl: Adding Sparse Controls
@@ -19,7 +19,7 @@ import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...models.unets.unet_motion_model import MotionAdapter
@@ -186,6 +186,7 @@ class AnimateDiffVideoToVideoPipeline(
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
FromSingleFileMixin,
):
r"""
Pipeline for video-to-video generation.
@@ -20,7 +20,7 @@ import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import (
AutoencoderKL,
ControlNetModel,
@@ -204,6 +204,7 @@ class AnimateDiffVideoToVideoControlNetPipeline(
StableDiffusionLoraLoaderMixin,
FreeInitMixin,
AnimateDiffFreeNoiseMixin,
FromSingleFileMixin,
):
r"""
Pipeline for video-to-video generation with ControlNet guidance.
@@ -14,7 +14,7 @@
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, GlmModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, CogView4Transformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -133,7 +134,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class CogView4Pipeline(DiffusionPipeline):
class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using CogView4.
@@ -143,13 +144,11 @@ class CogView4Pipeline(DiffusionPipeline):
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogView4 uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
text_encoder ([`GLMModel`]):
Frozen text-encoder. CogView4 uses [glm-4-9b-hf](https://huggingface.co/THUDM/glm-4-9b-hf).
tokenizer (`PreTrainedTokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
[PreTrainedTokenizer](https://huggingface.co/docs/transformers/main/en/main_classes/tokenizer#transformers.PreTrainedTokenizer).
transformer ([`CogView4Transformer2DModel`]):
A text conditioned `CogView4Transformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
@@ -215,7 +214,7 @@ class CogView4Pipeline(DiffusionPipeline):
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = self.text_encoder(
text_input_ids.to(self.text_encoder.model.device), output_hidden_states=True
text_input_ids.to(self.text_encoder.device), output_hidden_states=True
).hidden_states[-2]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
@@ -362,10 +361,16 @@ class CogView4Pipeline(DiffusionPipeline):
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
if prompt_embeds.shape[0] != negative_prompt_embeds.shape[0]:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
"`prompt_embeds` and `negative_prompt_embeds` must have the same batch size when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds.shape[-1] != negative_prompt_embeds.shape[-1]:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same dimension when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} and `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@@ -388,6 +393,10 @@ class CogView4Pipeline(DiffusionPipeline):
def interrupt(self):
return self._interrupt
@property
def attention_kwargs(self):
return self._attention_kwargs
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
@@ -409,6 +418,7 @@ class CogView4Pipeline(DiffusionPipeline):
crops_coords_top_left: Tuple[int, int] = (0, 0),
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -522,6 +532,7 @@ class CogView4Pipeline(DiffusionPipeline):
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# Default call parameters
@@ -611,6 +622,7 @@ class CogView4Pipeline(DiffusionPipeline):
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
@@ -623,6 +635,7 @@ class CogView4Pipeline(DiffusionPipeline):
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
@@ -757,15 +757,9 @@ class StableDiffusionXLControlNetUnionPipeline(
for images_ in image:
for image_ in images_:
self.check_image(image_, prompt, prompt_embeds)
else:
assert False
# Check `controlnet_conditioning_scale`
# TODO Update for https://github.com/huggingface/diffusers/pull/10723
if isinstance(controlnet, ControlNetUnionModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(controlnet, MultiControlNetUnionModel):
if isinstance(controlnet, MultiControlNetUnionModel):
if isinstance(controlnet_conditioning_scale, list):
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
@@ -776,8 +770,6 @@ class StableDiffusionXLControlNetUnionPipeline(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False
if len(control_guidance_start) != len(control_guidance_end):
raise ValueError(
@@ -808,8 +800,6 @@ class StableDiffusionXLControlNetUnionPipeline(
for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
if max(_control_mode) >= _controlnet.config.num_control_type:
raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
else:
assert False
# Equal number of `image` and `control_mode` elements
if isinstance(controlnet, ControlNetUnionModel):
@@ -823,8 +813,6 @@ class StableDiffusionXLControlNetUnionPipeline(
elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
raise ValueError("Expected len(control_image) == len(control_mode)")
else:
assert False
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@@ -1201,18 +1189,6 @@ class StableDiffusionXLControlNetUnionPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else 1
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if not isinstance(control_image, list):
control_image = [control_image]
else:
@@ -1221,8 +1197,25 @@ class StableDiffusionXLControlNetUnionPipeline(
if not isinstance(control_mode, list):
control_mode = [control_mode]
if isinstance(controlnet, MultiControlNetUnionModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
if isinstance(controlnet, MultiControlNetUnionModel):
control_image = [[item] for item in control_image]
control_mode = [[item] for item in control_mode]
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
control_guidance_start, control_guidance_end = (
mult * [control_guidance_start],
mult * [control_guidance_end],
)
if isinstance(controlnet_conditioning_scale, float):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
# 1. Check inputs
self.check_inputs(
@@ -1357,9 +1350,6 @@ class StableDiffusionXLControlNetUnionPipeline(
control_image = control_images
height, width = control_image[0][0].shape[-2:]
else:
assert False
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
@@ -1397,7 +1387,7 @@ class StableDiffusionXLControlNetUnionPipeline(
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetUnionModel) else keeps)
controlnet_keep.append(keeps)
# 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width)
@@ -0,0 +1,52 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_easyanimate"] = ["EasyAnimatePipeline"]
_import_structure["pipeline_easyanimate_control"] = ["EasyAnimateControlPipeline"]
_import_structure["pipeline_easyanimate_inpaint"] = ["EasyAnimateInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_easyanimate import EasyAnimatePipeline
from .pipeline_easyanimate_control import EasyAnimateControlPipeline
from .pipeline_easyanimate_inpaint import EasyAnimateInpaintPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
+770
View File
@@ -0,0 +1,770 @@
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import torch
from transformers import (
BertModel,
BertTokenizer,
Qwen2Tokenizer,
Qwen2VLForConditionalGeneration,
)
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import EasyAnimatePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import EasyAnimatePipeline
>>> from diffusers.utils import export_to_video
>>> # Models: "alibaba-pai/EasyAnimateV5.1-12b-zh"
>>> pipe = EasyAnimatePipeline.from_pretrained(
... "alibaba-pai/EasyAnimateV5.1-7b-zh-diffusers", torch_dtype=torch.float16
... ).to("cuda")
>>> prompt = (
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
... "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
... "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
... "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
... "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
... "atmosphere of this unique musical performance."
... )
>>> sample_size = (512, 512)
>>> video = pipe(
... prompt=prompt,
... guidance_scale=6,
... negative_prompt="bad detailed",
... height=sample_size[0],
... width=sample_size[1],
... num_inference_steps=50,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class EasyAnimatePipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using EasyAnimate.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
Args:
vae ([`AutoencoderKLMagvit`]):
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
transformer ([`EasyAnimateTransformer3DModel`]):
The EasyAnimate model designed by EasyAnimate Team.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKLMagvit,
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
transformer: EasyAnimateTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.enable_text_attention_mask = (
self.transformer.config.enable_text_attention_mask
if getattr(self, "transformer", None) is not None
else True
)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
def encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
dtype (`torch.dtype`):
torch dtype
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
"""
dtype = dtype or self.text_encoder.dtype
device = device or self.text_encoder.device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
if isinstance(prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _prompt}],
}
for _prompt in prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
if self.enable_text_attention_mask:
# Inference: Generation of the output
prompt_embeds = self.text_encoder(
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
).hidden_states[-2]
else:
raise ValueError("LLM needs attention_mask")
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
if negative_prompt is not None and isinstance(negative_prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": negative_prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _negative_prompt}],
}
for _negative_prompt in negative_prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
negative_prompt_attention_mask = text_inputs.attention_mask
if self.enable_text_attention_mask:
# Inference: Generation of the output
negative_prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=negative_prompt_attention_mask,
output_hidden_states=True,
).hidden_states[-2]
else:
raise ValueError("LLM needs attention_mask")
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
height // self.vae_spatial_compression_ratio,
width // self.vae_spatial_compression_ratio,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
if hasattr(self.scheduler, "init_noise_sigma"):
latents = latents * self.scheduler.init_noise_sigma
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 49,
height: Optional[int] = 512,
width: Optional[int] = 512,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
timesteps: Optional[List[int]] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
):
r"""
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
Examples:
prompt (`str` or `List[str]`, *optional*):
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
num_frames (`int`, *optional*):
Length of the generated video (in frames).
height (`int`, *optional*):
Height of the generated image in pixels.
width (`int`, *optional*):
Width of the generated image in pixels.
num_inference_steps (`int`, *optional*, defaults to 50):
Number of denoising steps during generation. More steps generally yield higher quality images but slow
down inference.
guidance_scale (`float`, *optional*, defaults to 5.0):
Encourages the model to align outputs with prompts. A higher value may decrease image quality.
negative_prompt (`str` or `List[str]`, *optional*):
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images to generate for each prompt.
eta (`float`, *optional*, defaults to 0.0):
Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A generator to ensure reproducibility in image generation.
latents (`torch.Tensor`, *optional*):
Predefined latent tensors to condition generation.
prompt_embeds (`torch.Tensor`, *optional*):
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Embeddings for negative prompts. Overrides string inputs if defined.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the primary prompt embeddings.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for negative prompt embeddings.
output_type (`str`, *optional*, defaults to "latent"):
Format of the generated output, either as a PIL image or as a NumPy array.
return_dict (`bool`, *optional*, defaults to `True`):
If `True`, returns a structured output. Otherwise returns a simple tuple.
callback_on_step_end (`Callable`, *optional*):
Functions called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor names to be included in callback function calls.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Adjusts noise levels based on guidance scale.
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
Original dimensions of the output.
target_size (`Tuple[int, int]`, *optional*):
Desired output dimensions for calculations.
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
Coordinates for cropping.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. default height and width
height = int((height // 16) * 16)
width = int((width // 16) * 16)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = self.transformer.dtype
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# 4. Prepare timesteps
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, mu=1
)
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
num_frames,
height,
width,
dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if hasattr(self.scheduler, "scale_model_input"):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
return_dict=False,
)[0]
if noise_pred.size()[1] != self.vae.config.latent_channels:
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if not output_type == "latent":
latents = 1 / self.vae.config.scaling_factor * latents
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return EasyAnimatePipelineOutput(frames=video)
@@ -0,0 +1,994 @@
# Copyright 2025 The EasyAnimate team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
BertModel,
BertTokenizer,
Qwen2Tokenizer,
Qwen2VLForConditionalGeneration,
)
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKLMagvit, EasyAnimateTransformer3DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import EasyAnimatePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import EasyAnimateControlPipeline
>>> from diffusers.pipelines.easyanimate.pipeline_easyanimate_control import get_video_to_video_latent
>>> from diffusers.utils import export_to_video, load_video
>>> pipe = EasyAnimateControlPipeline.from_pretrained(
... "alibaba-pai/EasyAnimateV5.1-12b-zh-Control-diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> control_video = load_video(
... "https://huggingface.co/alibaba-pai/EasyAnimateV5.1-12b-zh-Control/blob/main/asset/pose.mp4"
... )
>>> prompt = (
... "In this sunlit outdoor garden, a beautiful woman is dressed in a knee-length, sleeveless white dress. "
... "The hem of her dress gently sways with her graceful dance, much like a butterfly fluttering in the breeze. "
... "Sunlight filters through the leaves, casting dappled shadows that highlight her soft features and clear eyes, "
... "making her appear exceptionally elegant. It seems as if every movement she makes speaks of youth and vitality. "
... "As she twirls on the grass, her dress flutters, as if the entire garden is rejoicing in her dance. "
... "The colorful flowers around her sway in the gentle breeze, with roses, chrysanthemums, and lilies each "
... "releasing their fragrances, creating a relaxed and joyful atmosphere."
... )
>>> sample_size = (672, 384)
>>> num_frames = 49
>>> input_video, _, _ = get_video_to_video_latent(control_video, num_frames, sample_size)
>>> video = pipe(
... prompt,
... num_frames=num_frames,
... negative_prompt="Twisted body, limb deformities, text subtitles, comics, stillness, ugliness, errors, garbled text.",
... height=sample_size[0],
... width=sample_size[1],
... control_video=input_video,
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
def preprocess_image(image, sample_size):
"""
Preprocess a single image (PIL.Image, numpy.ndarray, or torch.Tensor) to a resized tensor.
"""
if isinstance(image, torch.Tensor):
# If input is a tensor, assume it's in CHW format and resize using interpolation
image = torch.nn.functional.interpolate(
image.unsqueeze(0), size=sample_size, mode="bilinear", align_corners=False
).squeeze(0)
elif isinstance(image, Image.Image):
# If input is a PIL image, resize and convert to numpy array
image = image.resize((sample_size[1], sample_size[0]))
image = np.array(image)
elif isinstance(image, np.ndarray):
# If input is a numpy array, resize using PIL
image = Image.fromarray(image).resize((sample_size[1], sample_size[0]))
image = np.array(image)
else:
raise ValueError("Unsupported input type. Expected PIL.Image, numpy.ndarray, or torch.Tensor.")
# Convert to tensor if not already
if not isinstance(image, torch.Tensor):
image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0 # HWC -> CHW, normalize to [0, 1]
return image
def get_video_to_video_latent(input_video, num_frames, sample_size, validation_video_mask=None, ref_image=None):
if input_video is not None:
# Convert each frame in the list to tensor
input_video = [preprocess_image(frame, sample_size=sample_size) for frame in input_video]
# Stack all frames into a single tensor (F, C, H, W)
input_video = torch.stack(input_video)[:num_frames]
# Add batch dimension (B, F, C, H, W)
input_video = input_video.permute(1, 0, 2, 3).unsqueeze(0)
if validation_video_mask is not None:
# Handle mask input
validation_video_mask = preprocess_image(validation_video_mask, size=sample_size)
input_video_mask = torch.where(validation_video_mask < 240 / 255.0, 0.0, 255)
# Adjust mask dimensions to match video
input_video_mask = input_video_mask.unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0)
input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1])
input_video_mask = input_video_mask.to(input_video.device, input_video.dtype)
else:
input_video_mask = torch.zeros_like(input_video[:, :1])
input_video_mask[:, :, :] = 255
else:
input_video, input_video_mask = None, None
if ref_image is not None:
# Convert reference image to tensor
ref_image = preprocess_image(ref_image, size=sample_size)
ref_image = ref_image.permute(1, 0, 2, 3).unsqueeze(0) # Add batch dimension (B, C, H, W)
else:
ref_image = None
return input_video, input_video_mask, ref_image
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
# Resize mask information in magvit
def resize_mask(mask, latent, process_first_frame_only=True):
latent_size = latent.size()
if process_first_frame_only:
target_size = list(latent_size[2:])
target_size[0] = 1
first_frame_resized = F.interpolate(
mask[:, :, 0:1, :, :], size=target_size, mode="trilinear", align_corners=False
)
target_size = list(latent_size[2:])
target_size[0] = target_size[0] - 1
if target_size[0] != 0:
remaining_frames_resized = F.interpolate(
mask[:, :, 1:, :, :], size=target_size, mode="trilinear", align_corners=False
)
resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
else:
resized_mask = first_frame_resized
else:
target_size = list(latent_size[2:])
resized_mask = F.interpolate(mask, size=target_size, mode="trilinear", align_corners=False)
return resized_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class EasyAnimateControlPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using EasyAnimate.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
EasyAnimate uses one text encoder [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
Args:
vae ([`AutoencoderKLMagvit`]):
Variational Auto-Encoder (VAE) Model to encode and decode video to and from latent representations.
text_encoder (Optional[`~transformers.Qwen2VLForConditionalGeneration`, `~transformers.BertModel`]):
EasyAnimate uses [qwen2 vl](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) in V5.1.
tokenizer (Optional[`~transformers.Qwen2Tokenizer`, `~transformers.BertTokenizer`]):
A `Qwen2Tokenizer` or `BertTokenizer` to tokenize text.
transformer ([`EasyAnimateTransformer3DModel`]):
The EasyAnimate model designed by EasyAnimate Team.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with EasyAnimate to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKLMagvit,
text_encoder: Union[Qwen2VLForConditionalGeneration, BertModel],
tokenizer: Union[Qwen2Tokenizer, BertTokenizer],
transformer: EasyAnimateTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.enable_text_attention_mask = (
self.transformer.config.enable_text_attention_mask
if getattr(self, "transformer", None) is not None
else True
)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 4
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_spatial_compression_ratio,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
# Copied from diffusers.pipelines.easyanimate.pipeline_easyanimate.EasyAnimatePipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
dtype (`torch.dtype`):
torch dtype
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
max_sequence_length (`int`, *optional*): maximum sequence length to use for the prompt.
"""
dtype = dtype or self.text_encoder.dtype
device = device or self.text_encoder.device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
if isinstance(prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _prompt}],
}
for _prompt in prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
if self.enable_text_attention_mask:
# Inference: Generation of the output
prompt_embeds = self.text_encoder(
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
).hidden_states[-2]
else:
raise ValueError("LLM needs attention_mask")
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
if negative_prompt is not None and isinstance(negative_prompt, str):
messages = [
{
"role": "user",
"content": [{"type": "text", "text": negative_prompt}],
}
]
else:
messages = [
{
"role": "user",
"content": [{"type": "text", "text": _negative_prompt}],
}
for _negative_prompt in negative_prompt
]
text = [
self.tokenizer.apply_chat_template([m], tokenize=False, add_generation_prompt=True) for m in messages
]
text_inputs = self.tokenizer(
text=text,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_attention_mask=True,
padding_side="right",
return_tensors="pt",
)
text_inputs = text_inputs.to(self.text_encoder.device)
text_input_ids = text_inputs.input_ids
negative_prompt_attention_mask = text_inputs.attention_mask
if self.enable_text_attention_mask:
# Inference: Generation of the output
negative_prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=negative_prompt_attention_mask,
output_hidden_states=True,
).hidden_states[-2]
else:
raise ValueError("LLM needs attention_mask")
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device=device)
return prompt_embeds, negative_prompt_embeds, prompt_attention_mask, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
def prepare_latents(
self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None
):
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
(num_frames - 1) // self.vae_temporal_compression_ratio + 1,
height // self.vae_spatial_compression_ratio,
width // self.vae_spatial_compression_ratio,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
if hasattr(self.scheduler, "init_noise_sigma"):
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_control_latents(
self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the control to latents shape as we concatenate the control to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
if control is not None:
control = control.to(device=device, dtype=dtype)
bs = 1
new_control = []
for i in range(0, control.shape[0], bs):
control_bs = control[i : i + bs]
control_bs = self.vae.encode(control_bs)[0]
control_bs = control_bs.mode()
new_control.append(control_bs)
control = torch.cat(new_control, dim=0)
control = control * self.vae.config.scaling_factor
if control_image is not None:
control_image = control_image.to(device=device, dtype=dtype)
bs = 1
new_control_pixel_values = []
for i in range(0, control_image.shape[0], bs):
control_pixel_values_bs = control_image[i : i + bs]
control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0]
control_pixel_values_bs = control_pixel_values_bs.mode()
new_control_pixel_values.append(control_pixel_values_bs)
control_image_latents = torch.cat(new_control_pixel_values, dim=0)
control_image_latents = control_image_latents * self.vae.config.scaling_factor
else:
control_image_latents = None
return control, control_image_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_frames: Optional[int] = 49,
height: Optional[int] = 512,
width: Optional[int] = 512,
control_video: Union[torch.FloatTensor] = None,
control_camera_video: Union[torch.FloatTensor] = None,
ref_image: Union[torch.FloatTensor] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
timesteps: Optional[List[int]] = None,
):
r"""
Generates images or video using the EasyAnimate pipeline based on the provided prompts.
Examples:
prompt (`str` or `List[str]`, *optional*):
Text prompts to guide the image or video generation. If not provided, use `prompt_embeds` instead.
num_frames (`int`, *optional*):
Length of the generated video (in frames).
height (`int`, *optional*):
Height of the generated image in pixels.
width (`int`, *optional*):
Width of the generated image in pixels.
num_inference_steps (`int`, *optional*, defaults to 50):
Number of denoising steps during generation. More steps generally yield higher quality images but slow
down inference.
guidance_scale (`float`, *optional*, defaults to 5.0):
Encourages the model to align outputs with prompts. A higher value may decrease image quality.
negative_prompt (`str` or `List[str]`, *optional*):
Prompts indicating what to exclude in generation. If not specified, use `negative_prompt_embeds`.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images to generate for each prompt.
eta (`float`, *optional*, defaults to 0.0):
Applies to DDIM scheduling. Controlled by the eta parameter from the related literature.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A generator to ensure reproducibility in image generation.
latents (`torch.Tensor`, *optional*):
Predefined latent tensors to condition generation.
prompt_embeds (`torch.Tensor`, *optional*):
Text embeddings for the prompts. Overrides prompt string inputs for more flexibility.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Embeddings for negative prompts. Overrides string inputs if defined.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the primary prompt embeddings.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for negative prompt embeddings.
output_type (`str`, *optional*, defaults to "latent"):
Format of the generated output, either as a PIL image or as a NumPy array.
return_dict (`bool`, *optional*, defaults to `True`):
If `True`, returns a structured output. Otherwise returns a simple tuple.
callback_on_step_end (`Callable`, *optional*):
Functions called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor names to be included in callback function calls.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Adjusts noise levels based on guidance scale.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. default height and width
height = int((height // 16) * 16)
width = int((width // 16) * 16)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = self.transformer.dtype
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
text_encoder_index=0,
)
# 4. Prepare timesteps
if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler):
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, mu=1
)
else:
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
num_frames,
height,
width,
dtype,
device,
generator,
latents,
)
if control_camera_video is not None:
control_video_latents = resize_mask(control_camera_video, latents, process_first_frame_only=True)
control_video_latents = control_video_latents * 6
control_latents = (
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
).to(device, dtype)
elif control_video is not None:
batch_size, channels, num_frames, height_video, width_video = control_video.shape
control_video = self.image_processor.preprocess(
control_video.permute(0, 2, 1, 3, 4).reshape(
batch_size * num_frames, channels, height_video, width_video
),
height=height,
width=width,
)
control_video = control_video.to(dtype=torch.float32)
control_video = control_video.reshape(batch_size, num_frames, channels, height, width).permute(
0, 2, 1, 3, 4
)
control_video_latents = self.prepare_control_latents(
None,
control_video,
batch_size,
height,
width,
dtype,
device,
generator,
self.do_classifier_free_guidance,
)[1]
control_latents = (
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
).to(device, dtype)
else:
control_video_latents = torch.zeros_like(latents).to(device, dtype)
control_latents = (
torch.cat([control_video_latents] * 2) if self.do_classifier_free_guidance else control_video_latents
).to(device, dtype)
if ref_image is not None:
batch_size, channels, num_frames, height_video, width_video = ref_image.shape
ref_image = self.image_processor.preprocess(
ref_image.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height_video, width_video),
height=height,
width=width,
)
ref_image = ref_image.to(dtype=torch.float32)
ref_image = ref_image.reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
ref_image_latents = self.prepare_control_latents(
None,
ref_image,
batch_size,
height,
width,
prompt_embeds.dtype,
device,
generator,
self.do_classifier_free_guidance,
)[1]
ref_image_latents_conv_in = torch.zeros_like(latents)
if latents.size()[2] != 1:
ref_image_latents_conv_in[:, :, :1] = ref_image_latents
ref_image_latents_conv_in = (
torch.cat([ref_image_latents_conv_in] * 2)
if self.do_classifier_free_guidance
else ref_image_latents_conv_in
).to(device, dtype)
control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
else:
ref_image_latents_conv_in = torch.zeros_like(latents)
ref_image_latents_conv_in = (
torch.cat([ref_image_latents_conv_in] * 2)
if self.do_classifier_free_guidance
else ref_image_latents_conv_in
).to(device, dtype)
control_latents = torch.cat([control_latents, ref_image_latents_conv_in], dim=1)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
# To latents.device
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
if hasattr(self.scheduler, "scale_model_input"):
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# expand scalar t to 1-D tensor to match the 1st dim of latent_model_input
t_expand = torch.tensor([t] * latent_model_input.shape[0], device=device).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
t_expand,
encoder_hidden_states=prompt_embeds,
control_latents=control_latents,
return_dict=False,
)[0]
if noise_pred.size()[1] != self.vae.config.latent_channels:
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
# Convert to tensor
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return EasyAnimatePipelineOutput(frames=video)
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class EasyAnimatePipelineOutput(BaseOutput):
r"""
Output class for EasyAnimate pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
@@ -694,7 +694,7 @@ class FluxPipeline(
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.0):
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >

Some files were not shown because too many files have changed in this diff Show More