Compare commits

..

51 Commits

Author SHA1 Message Date
Dhruv Nair 6a94ef7388 update 2025-06-12 03:22:03 +02:00
DN6 f20e4afbaa update 2025-06-10 23:56:08 +05:30
DN6 7787ec11c8 update 2025-06-10 23:51:08 +05:30
DN6 c5c7588648 update 2025-06-10 23:45:32 +05:30
Dhruv Nair 6ccaed77ed update 2025-06-10 16:15:20 +02:00
Dhruv Nair f6ece89c6d update 2025-06-10 10:51:46 +02:00
DN6 542a6034d3 update 2025-06-10 13:55:23 +05:30
Hameer Abbasi e95ac9d82f Merge pull request #1 from iddl/chroma-fixes 2025-06-10 03:17:04 +02:00
Ivan DiLernia 104e1636b2 Get chroma to a functioning state 2025-06-09 11:48:50 -04:00
Hameer Abbasi 373106cedb Add attention masking. 2025-05-19 12:59:33 +05:00
Hameer Abbasi 8ceed7d3ae Initial commit: Chroma as a FLUX.1 variant. 2025-05-17 05:25:02 +05:00
Sayak Paul 9836f0e000 [docs] Regional compilation docs (#11556)
* add regional compilation docs.

* minor.

* reviwer feedback.

* Update docs/source/en/optimization/torch2.0.md

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>

---------

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
2025-05-15 19:11:24 +05:30
Sayak Paul 20379d9d13 [tests] add tests for combining layerwise upcasting and groupoffloading. (#11558)
* add tests for combining layerwise upcasting and groupoffloading.

* feedback
2025-05-15 17:16:44 +05:30
Animesh Jain 3a6caba8e4 [gguf] Refactor __torch_function__ to avoid unnecessary computation (#11551)
* [gguf] Refactor __torch_function__ to avoid unnecessary computation

This helps with torch.compile compilation lantency. Avoiding unnecessary
computation should also lead to a slightly improved eager latency.

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-05-15 14:38:18 +05:30
Dhruv Nair 4267d8f4eb [Single File] GGUF/Single File Support for HiDream (#11550)
* update

* update

* update

* update

* update

* update

* update
2025-05-15 12:25:18 +05:30
Seokhyeon Jeong f4fa3beee7 [tests] Add torch.compile test for UNet2DConditionModel (#11537)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-05-14 11:26:12 +05:30
Anwesha Chowdhury 7e3353196c Fix deprecation warnings in test_ltx_image2video.py (#11538)
Fixed 2 warnings that were raised during running LTXImageToVideoPipelineFastTests

Co-authored-by: achowdhury1211@gmail.com <anwesha@LAPTOP-E5QGFMOQ>
2025-05-13 08:05:52 -10:00
Meatfucker 8c249d1401 Update pipeline_flux_img2img.py to add missing vae_slicing and vae_tiling calls. (#11545)
* Update pipeline_flux_img2img.py

Adds missing vae_slicing and vae_tiling calls to FluxImage2ImagePipeline

* Update src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

* Update src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>

---------

Co-authored-by: Álvaro Somoza <asomoza@users.noreply.github.com>
2025-05-13 12:31:45 -04:00
Sayak Paul b555a03723 [tests] Enable testing for HiDream transformer (#11478)
* add tests for hidream transformer model.

* fix

* Update tests/models/transformers/test_models_transformer_hidream.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-05-13 21:01:20 +05:30
Aryan 06fee551e9 LTX Video 0.9.7 (#11516)
* add upsampling pipeline

* ltx upsample pipeline conversion; pipeline fixes

* make fix-copies

* remove print

* add vae convenience methods

* update

* add tests

* support denoising strength for upscaling & video-to-video

* update docs

* update doc checkpoints

* update docs

* fix

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-05-13 14:57:03 +05:30
Linoy Tsaban 8b99f7e157 [LoRA] small change to support Hunyuan LoRA Loading for FramePack (#11546)
init
2025-05-13 10:15:06 +03:00
Kenneth Gerald Hamilton 07dd6f8c0e [train_dreambooth.py] Fix the LR Schedulers when num_train_epochs is passed in a distributed training env (#11239)
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-05-13 07:34:01 +05:30
johannaSommer f8d4a1e283 fix: remove torch_dtype="auto" option from docstrings (#11513)
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-05-12 15:42:35 -10:00
Abdellah Oumida ddd0cfb497 Fix typo in train_diffusion_orpo_sdxl_lora_wds.py (#11541) 2025-05-12 15:28:29 -10:00
Zhong-Yu Li 4f438de35a Add VisualCloze (#11377)
* VisualCloze

* style quality

* add docs

* add docs

* typo

* Update docs/source/en/api/pipelines/visualcloze.md

* delete einops

* style quality

* Update src/diffusers/pipelines/visualcloze/pipeline_visualcloze.py

* reorg

* refine doc

* style quality

* typo

* typo

* Update src/diffusers/image_processor.py

* add comment

* test

* style

* Modified based on review

* style

* restore image_processor

* update example url

* style

* fix-copies

* VisualClozeGenerationPipeline

* combine

* tests docs

* remove VisualClozeUpsamplingPipeline

* style

* quality

* test examples

* quality style

* typo

* make fix-copies

* fix test_callback_cfg and test_save_load_dduf in VisualClozePipelineFastTests

* add EXAMPLE_DOC_STRING to VisualClozeGenerationPipeline

* delete maybe_free_model_hooks from pipeline_visualcloze_combined

* Apply suggestions from code review

* fix test_save_load_local test; add reason for skipping cfg test

* more save_load test fixes

* fix tests in generation pipeline tests
2025-05-13 02:46:51 +05:30
Evan Han 98cc6d05e4 [test_models_transformer_ltx.py] help us test torch.compile() for impactful models (#11512)
* Update test_models_transformer_ltx.py

* Update test_models_transformer_ltx.py

* Update test_models_transformer_ltx.py

* Update test_models_transformer_ltx.py

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-05-12 19:26:15 +05:30
Yao Matrix c3726153fd enable several pipeline integration tests on XPU (#11526)
* enable kandinsky2_2 integration test cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* enable latent_diffusion, dance_diffusion, musicldm, shap_e integration
uts on xpu

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2025-05-12 16:51:37 +05:30
Aryan e48f6aeeb4 Hunyuan Video Framepack F1 (#11534)
* support framepack f1

* update docs

* update toctree

* remove typo
2025-05-12 16:11:10 +05:30
Sayak Paul 01abfc8736 [tests] add tests for framepack transformer model. (#11520)
* start.

* add tests for framepack transformer model.

* merge conflicts.

* make to square.

* fixes
2025-05-11 09:50:06 +05:30
Aryan 92fe689f06 Change Framepack transformer layer initialization order (#11535)
update
2025-05-09 23:09:50 +05:30
Yao Matrix 0ba1f76d4d enable print_env on xpu (#11507)
* detect xpu in print_env

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* enhance code, test passed on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Signed-off-by: Yao Matrix <matrix.yao@intel.com>
2025-05-09 16:30:51 +05:30
Yao Matrix d6bf268a4a enable dit integration cases on xpu (#11523)
* enable dit integration test on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* fix style

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
2025-05-09 16:06:50 +05:30
James Xu 3c0a0129fe [LTXPipeline] Update latents dtype to match VAE dtype (#11533)
fix: update latents dtype to match vae
2025-05-09 16:05:21 +05:30
Yao Matrix 2d380895e5 enable 7 cases on XPU (#11503)
* enable 7 cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* calibrate A100 expectations

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-05-09 15:52:08 +05:30
Sayak Paul 0c47c954f3 [LoRA] support non-diffusers hidream loras (#11532)
* support non-diffusers hidream loras

* make fix-copies
2025-05-09 14:42:39 +05:30
Sayak Paul 7acf8345f6 [Tests] Enable more general testing for torch.compile() with LoRA hotswapping (#11322)
* refactor hotswap tester.

* fix seeds..

* add to nightly ci.

* move comment.

* move to nightly
2025-05-09 11:29:06 +05:30
Sayak Paul 599c887164 feat: pipeline-level quantization config (#11130)
* feat: pipeline-level quant config.

Co-authored-by: SunMarc <marc.sun@hotmail.fr>

condition better.

support mapping.

improvements.

[Quantization] Add Quanto backend (#10756)

* update

* updaet

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update docs/source/en/quantization/quanto.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* Update src/diffusers/quantizers/quanto/utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

[Single File] Add single file loading for SANA Transformer (#10947)

* added support for from_single_file

* added diffusers mapping script

* added testcase

* bug fix

* updated tests

* corrected code quality

* corrected code quality

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] Improve warning messages when LoRA loading becomes a no-op (#10187)

* updates

* updates

* updates

* updates

* notebooks revert

* fix-copies.

* seeing

* fix

* revert

* fixes

* fixes

* fixes

* remove print

* fix

* conflicts ii.

* updates

* fixes

* better filtering of prefix.

---------

Co-authored-by: hlky <hlky@hlky.ac>

[LoRA] CogView4 (#10981)

* update

* make fix-copies

* update

[Tests] improve quantization tests by additionally measuring the inference memory savings (#11021)

* memory usage tests

* fixes

* gguf

[`Research Project`] Add AnyText: Multilingual Visual Text Generation And Editing (#8998)

* Add initial template

* Second template

* feat: Add TextEmbeddingModule to AnyTextPipeline

* feat: Add AuxiliaryLatentModule template to AnyTextPipeline

* Add bert tokenizer from the anytext repo for now

* feat: Update AnyTextPipeline's modify_prompt method

This commit adds improvements to the modify_prompt method in the AnyTextPipeline class. The method now handles special characters and replaces selected string prompts with a placeholder. Additionally, it includes a check for Chinese text and translation using the trans_pipe.

* Fill in the `forward` pass of `AuxiliaryLatentModule`

* `make style && make quality`

* `chore: Update bert_tokenizer.py with a TODO comment suggesting the use of the transformers library`

* Update error handling to raise and logging

* Add `create_glyph_lines` function into `TextEmbeddingModule`

* make style

* Up

* Up

* Up

* Up

* Remove several comments

* refactor: Remove ControlNetConditioningEmbedding and update code accordingly

* Up

* Up

* up

* refactor: Update AnyTextPipeline to include new optional parameters

* up

* feat: Add OCR model and its components

* chore: Update `TextEmbeddingModule` to include OCR model components and dependencies

* chore: Update `AuxiliaryLatentModule` to include VAE model and its dependencies for masked image in the editing task

* `make style`

* refactor: Update `AnyTextPipeline`'s docstring

* Update `AuxiliaryLatentModule` to include info dictionary so that text processing is done once

* simplify

* `make style`

* Converting `TextEmbeddingModule` to ordinary `encode_prompt()` function

* Simplify for now

* `make style`

* Up

* feat: Add scripts to convert AnyText controlnet to diffusers

* `make style`

* Fix: Move glyph rendering to `TextEmbeddingModule` from `AuxiliaryLatentModule`

* make style

* Up

* Simplify

* Up

* feat: Add safetensors module for loading model file

* Fix device issues

* Up

* Up

* refactor: Simplify

* refactor: Simplify code for loading models and handling data types

* `make style`

* refactor: Update to() method in FrozenCLIPEmbedderT3 and TextEmbeddingModule

* refactor: Update dtype in embedding_manager.py to match proj.weight

* Up

* Add attribution and adaptation information to pipeline_anytext.py

* Update usage example

* Will refactor `controlnet_cond_embedding` initialization

* Add `AnyTextControlNetConditioningEmbedding` template

* Refactor organization

* style

* style

* Move custom blocks from `AuxiliaryLatentModule` to `AnyTextControlNetConditioningEmbedding`

* Follow one-file policy

* style

* [Docs] Update README and pipeline_anytext.py to use AnyTextControlNetModel

* [Docs] Update import statement for AnyTextControlNetModel in pipeline_anytext.py

* [Fix] Update import path for ControlNetModel, ControlNetOutput in anytext_controlnet.py

* Refactor AnyTextControlNet to use configurable conditioning embedding channels

* Complete control net conditioning embedding in AnyTextControlNetModel

* up

* [FIX] Ensure embeddings use correct device in AnyTextControlNetModel

* up

* up

* style

* [UPDATE] Revise README and example code for AnyTextPipeline integration with DiffusionPipeline

* [UPDATE] Update example code in anytext.py to use correct font file and improve clarity

* down

* [UPDATE] Refactor BasicTokenizer usage to a new Checker class for text processing

* update pillow

* [UPDATE] Remove commented-out code and unnecessary docstring in anytext.py and anytext_controlnet.py for improved clarity

* [REMOVE] Delete frozen_clip_embedder_t3.py as it is in the anytext.py file

* [UPDATE] Replace edict with dict for configuration in anytext.py and RecModel.py for consistency

* 🆙

* style

* [UPDATE] Revise README.md for clarity, remove unused imports in anytext.py, and add author credits in anytext_controlnet.py

* style

* Update examples/research_projects/anytext/README.md

Co-authored-by: Aryan <contact.aryanvs@gmail.com>

* Remove commented-out image preparation code in AnyTextPipeline

* Remove unnecessary blank line in README.md

[Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6  (#11018)

* update

* update

* update

* update

* update

* update

* update

* update

* update

fix: mixture tiling sdxl pipeline - adjust gerating time_ids & embeddings  (#11012)

small fix on generating time_ids & embeddings

[LoRA] support wan i2v loras from the world. (#11025)

* support wan i2v loras from the world.

* remove copied from.

* upates

* add lora.

Fix SD3 IPAdapter feature extractor (#11027)

chore: fix help messages in advanced diffusion examples (#10923)

Fix missing **kwargs in lora_pipeline.py (#11011)

* Update lora_pipeline.py

* Apply style fixes

* fix-copies

---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

Fix for multi-GPU WAN inference (#10997)

Ensure that hidden_state and shift/scale are on the same device when running with multiple GPUs

Co-authored-by: Jimmy <39@🇺🇸.com>

[Refactor] Clean up import utils boilerplate (#11026)

* update

* update

* update

Use `output_size` in `repeat_interleave` (#11030)

[hybrid inference 🍯🐝] Add VAE encode (#11017)

* [hybrid inference 🍯🐝] Add VAE encode

* _toctree: add vae encode

* Add endpoints, tests

* vae_encode docs

* vae encode benchmarks

* api reference

* changelog

* Update docs/source/en/hybrid_inference/overview.md

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

Wan Pipeline scaling fix, type hint warning, multi generator fix (#11007)

* Wan Pipeline scaling fix, type hint warning, multi generator fix

* Apply suggestions from code review

[LoRA] change to warning from info when notifying the users about a LoRA no-op (#11044)

* move to warning.

* test related changes.

Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline (#10827)

* Rename Lumina(2)Text2ImgPipeline -> Lumina(2)Pipeline

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>

making ```formatted_images``` initialization compact (#10801)

compact writing

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

Fix aclnnRepeatInterleaveIntWithDim error on NPU for get_1d_rotary_pos_embed (#10820)

* get_1d_rotary_pos_embed support npu

* Update src/diffusers/models/embeddings.py

---------

Co-authored-by: Kai zheng <kaizheng@KaideMacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>

[Tests] restrict memory tests for quanto for certain schemes. (#11052)

* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

* style

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[LoRA] feat: support non-diffusers wan t2v loras. (#11059)

feat: support non-diffusers wan t2v loras.

[examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)

Fix: dtype mismatch of prompt embeddings in sd3 controlnet training

Co-authored-by: Andreas Jörg <andreasjoerg@MacBook-Pro-von-Andreas-2.fritz.box>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

reverts accidental change that removes attn_mask in attn. Improves fl… (#11065)

reverts accidental change that removes attn_mask in attn. Improves flux ptxla by using flash block sizes. Moves encoding outside the for loop.

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Fix deterministic issue when getting pipeline dtype and device (#10696)

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

[Tests] add requires peft decorator. (#11037)

* add requires peft decorator.

* install peft conditionally.

* conditional deps.

Co-authored-by: DN6 <dhruv.nair@gmail.com>

---------

Co-authored-by: DN6 <dhruv.nair@gmail.com>

CogView4 Control Block (#10809)

* cogview4 control training

---------

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>

[CI] pin transformers version for benchmarking. (#11067)

pin transformers version for benchmarking.

updates

Fix Wan I2V Quality (#11087)

* fix_wan_i2v_quality

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update pipeline_wan_i2v.py

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

LTX 0.9.5 (#10968)

* update

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>

make PR GPU tests conditioned on styling. (#11099)

Group offloading improvements (#11094)

update

Fix pipeline_flux_controlnet.py (#11095)

* Fix pipeline_flux_controlnet.py

* Fix style

update readme instructions. (#11096)

Co-authored-by: Juan Acevedo <jfacevedo@google.com>

Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)

Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP

Fix Group offloading behaviour when using streams (#11097)

* update

* update

Quality options in `export_to_video` (#11090)

* Quality options in `export_to_video`

* make style

improve more.

add placeholders for docstrings.

formatting.

smol fix.

solidify validation and annotation

* Revert "feat: pipeline-level quant config."

This reverts commit 316ff46b76.

* feat: implement pipeline-level quantization config

Co-authored-by: SunMarc <marc@huggingface.co>

* update

* fixes

* fix validation.

* add tests and other improvements.

* add tests

* import quality

* remove prints.

* add docs.

* fixes to docs.

* doc fixes.

* doc fixes.

* add validation to the input quantization_config.

* clarify recommendations.

* docs

* add to ci.

* todo.

---------

Co-authored-by: SunMarc <marc@huggingface.co>
2025-05-09 10:04:44 +05:30
Sayak Paul 393aefcdc7 [tests] fix audioldm2 for transformers main. (#11522)
fix audioldm2 for transformers main.
2025-05-08 21:13:42 +05:30
Aryan 6674a5157f Conditionally import torchvision in Cosmos transformer (#11524)
fix
2025-05-08 19:37:47 +05:30
scxue 784db0eaab Add cross attention type for Sana-Sprint training in diffusers. (#11514)
* test permission

* Add cross attention type for Sana-Sprint.

* Add Sana-Sprint training script in diffusers.

* make style && make quality;

* modify the attention processor with `set_attn_processor` and change `SanaAttnProcessor3_0` to `SanaVanillaAttnProcessor`

* Add import for SanaVanillaAttnProcessor

* Add README file.

* Apply suggestions from code review

* style

* Update examples/research_projects/sana/README.md

---------

Co-authored-by: lawrence-cj <cjs1020440147@icloud.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-05-08 18:55:29 +05:30
Linoy Tsaban 66e50d4e24 [LoRA] make lora alpha and dropout configurable (#11467)
* add lora_alpha and lora_dropout

* Apply style fixes

* add lora_alpha and lora_dropout

* Apply style fixes

* revert lora_alpha until #11324 is merged

* Apply style fixes

* empty commit

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-05-08 11:54:50 +03:00
sayakpaul c5c34a4591 Revert "fix audioldm"
This reverts commit 87e508f11f.
2025-05-08 11:30:29 +05:30
sayakpaul 87e508f11f fix audioldm 2025-05-08 11:30:11 +05:30
YiYi Xu 53bd367b03 clean up the __Init__ for stable_diffusion (#11500)
up
2025-05-07 07:01:17 -10:00
Aryan 7b904941bc Cosmos (#10660)
* begin transformer conversion

* refactor

* refactor

* refactor

* refactor

* refactor

* refactor

* update

* add conversion script

* add pipeline

* make fix-copies

* remove einops

* update docs

* gradient checkpointing

* add transformer test

* update

* debug

* remove prints

* match sigmas

* add vae pt. 1

* finish CV* vae

* update

* update

* update

* update

* update

* update

* make fix-copies

* update

* make fix-copies

* fix

* update

* update

* make fix-copies

* update

* update tests

* handle device and dtype for safety checker; required in latest diffusers

* remove enable_gqa and use repeat_interleave instead

* enforce safety checker; use dummy checker in fast tests

* add review suggestion for ONNX export

Co-Authored-By: Asfiya Baig <asfiyab@nvidia.com>

* fix safety_checker issues when not passed explicitly

We could either do what's done in this commit, or update the Cosmos examples to explicitly pass the safety checker

* use cosmos guardrail package

* auto format docs

* update conversion script to support 14B models

* update name CosmosPipeline -> CosmosTextToWorldPipeline

* update docs

* fix docs

* fix group offload test failing for vae

---------

Co-authored-by: Asfiya Baig <asfiyab@nvidia.com>
2025-05-07 20:59:09 +05:30
Sayak Paul fb29132b98 [docs] minor updates to bitsandbytes docs. (#11509)
* minor updates to bitsandbytes docs.

* Apply suggestions from code review
2025-05-06 18:52:18 +05:30
Valeriy Selitskiy 79371661d1 [lora_conversion] Enhance key handling for OneTrainer components in LORA conversion utility (#11441) (#11487)
* [lora_conversion] Enhance key handling for OneTrainer components in LORA conversion utility (#11441)

* Update src/diffusers/loaders/lora_conversion_utils.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-05-06 18:44:58 +05:30
Yao Matrix 8c661ea586 enable lora cases on XPU (#11506)
* enable lora cases on XPU

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

* remove hunyuanvideo xpu expectation

Signed-off-by: Yao Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Yao Matrix <matrix.yao@intel.com>
2025-05-06 14:59:50 +05:30
Aryan d7ffe60166 Hunyuan Video Framepack (#11428)
* add transformer

* add pipeline

* fixes

* make fix-copies

* update

* add flux mu shift

* update example snippet

* debug

* cleanup

* batch_size=1 optimization

* add pipeline test

* fix for model cpu offloading'

* add last_image support; credits: https://github.com/lllyasviel/FramePack/pull/167

* update example with flf2v

* update penguin url

* fix test

* address review comment: https://github.com/huggingface/diffusers/pull/11428#discussion_r2071032371

* address review comment: https://github.com/huggingface/diffusers/pull/11428#discussion_r2071087689

* Update src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-05-06 14:59:38 +05:30
Sayak Paul 10bee525e7 [LoRA] use removeprefix to preserve sanity. (#11493)
* use removeprefix to preserve sanity.

* f-string.
2025-05-06 12:17:57 +05:30
Sayak Paul d88ae1f52a update dep table. (#11504)
* update dep table.

* fix
2025-05-06 11:14:07 +05:30
135 changed files with 14960 additions and 370 deletions
+55
View File
@@ -142,6 +142,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
@@ -525,6 +526,60 @@ jobs:
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
run_nightly_pipeline_level_quantization_tests:
name: Torch quantization nightly tests
strategy:
fail-fast: false
max-parallel: 2
runs-on:
group: aws-g6e-xlarge-plus
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "20gb" --ipc host --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install -U bitsandbytes optimum_quanto
python -m uv pip install pytest-reportlog
- name: Environment
run: |
python utils/print_env.py
- name: Pipeline-level quantization tests on GPU
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
--make-reports=tests_pipeline_level_quant_torch_cuda \
--report-log=tests_pipeline_level_quant_torch_cuda.log \
tests/quantization/test_pipeline_level_quantization.py
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
- name: Generate Report and Notify Channel
if: always()
run: |
pip install slack_sdk tabulate
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
# M1 runner currently not well supported
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
# run_nightly_tests_apple_m1:
+10
View File
@@ -295,6 +295,8 @@
title: CogView4Transformer2DModel
- local: api/models/consisid_transformer3d
title: ConsisIDTransformer3DModel
- local: api/models/cosmos_transformer3d
title: CosmosTransformer3DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
@@ -363,6 +365,8 @@
title: AutoencoderKLAllegro
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_cosmos
title: AutoencoderKLCosmos
- local: api/models/autoencoder_kl_hunyuan_video
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoderkl_ltx_video
@@ -433,6 +437,8 @@
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/controlnet_union
title: ControlNetUnion
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
@@ -451,6 +457,8 @@
title: Flux
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
@@ -567,6 +575,8 @@
title: UniDiffuser
- local: api/pipelines/value_guided_sampling
title: Value-guided sampling
- local: api/pipelines/visualcloze
title: VisualCloze
- local: api/pipelines/wan
title: Wan
- local: api/pipelines/wuerstchen
@@ -0,0 +1,40 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLCosmos
[Cosmos Tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer).
Supported models:
- [nvidia/Cosmos-1.0-Tokenizer-CV8x8x8](https://huggingface.co/nvidia/Cosmos-1.0-Tokenizer-CV8x8x8)
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLCosmos
vae = AutoencoderKLCosmos.from_pretrained("nvidia/Cosmos-1.0-Tokenizer-CV8x8x8", subfolder="vae")
```
## AutoencoderKLCosmos
[[autodoc]] AutoencoderKLCosmos
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# CosmosTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
The model can be loaded with the following code snippet.
```python
from diffusers import CosmosTransformer3DModel
transformer = CosmosTransformer3DModel.from_pretrained("nvidia/Cosmos-1.0-Diffusion-7B-Text2World", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## CosmosTransformer3DModel
[[autodoc]] CosmosTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
@@ -21,6 +21,22 @@ from diffusers import HiDreamImageTransformer2DModel
transformer = HiDreamImageTransformer2DModel.from_pretrained("HiDream-ai/HiDream-I1-Full", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## Loading GGUF quantized checkpoints for HiDream-I1
GGUF checkpoints for the `HiDreamImageTransformer2DModel` can be loaded using `~FromOriginalModelMixin.from_single_file`
```python
import torch
from diffusers import GGUFQuantizationConfig, HiDreamImageTransformer2DModel
ckpt_path = "https://huggingface.co/city96/HiDream-I1-Dev-gguf/blob/main/hidream-i1-dev-Q2_K.gguf"
transformer = HiDreamImageTransformer2DModel.from_single_file(
ckpt_path,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
torch_dtype=torch.bfloat16
)
```
## HiDreamImageTransformer2DModel
[[autodoc]] HiDreamImageTransformer2DModel
+41
View File
@@ -0,0 +1,41 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Cosmos
[Cosmos World Foundation Model Platform for Physical AI](https://huggingface.co/papers/2501.03575) by NVIDIA.
*Physical AI needs to be trained digitally first. It needs a digital twin of itself, the policy model, and a digital twin of the world, the world model. In this paper, we present the Cosmos World Foundation Model Platform to help developers build customized world models for their Physical AI setups. We position a world foundation model as a general-purpose world model that can be fine-tuned into customized world models for downstream applications. Our platform covers a video curation pipeline, pre-trained world foundation models, examples of post-training of pre-trained world foundation models, and video tokenizers. To help Physical AI builders solve the most critical problems of our society, we make our platform open-source and our models open-weight with permissive licenses available via https://github.com/NVIDIA/Cosmos.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
- all
- __call__
## CosmosVideoToWorldPipeline
[[autodoc]] CosmosVideoToWorldPipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
+209
View File
@@ -0,0 +1,209 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# Framepack
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
[Packing Input Frame Context in Next-Frame Prediction Models for Video Generation](https://arxiv.org/abs/2504.12626) by Lvmin Zhang and Maneesh Agrawala.
*We present a neural network structure, FramePack, to train next-frame (or next-frame-section) prediction models for video generation. The FramePack compresses input frames to make the transformer context length a fixed number regardless of the video length. As a result, we are able to process a large number of frames using video diffusion with computation bottleneck similar to image diffusion. This also makes the training video batch sizes significantly higher (batch sizes become comparable to image diffusion training). We also propose an anti-drifting sampling method that generates frames in inverted temporal order with early-established endpoints to avoid exposure bias (error accumulation over iterations). Finally, we show that existing video diffusion models can be finetuned with FramePack, and their visual quality may be improved because the next-frame prediction supports more balanced diffusion schedulers with less extreme flow shift timesteps.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## Available models
| Model name | Description |
|:---|:---|
- [`lllyasviel/FramePackI2V_HY`](https://huggingface.co/lllyasviel/FramePackI2V_HY) | Trained with the "inverted anti-drifting" strategy as described in the paper. Inference requires setting `sampling_type="inverted_anti_drifting"` when running the pipeline. |
- [`lllyasviel/FramePack_F1_I2V_HY_20250503`](https://huggingface.co/lllyasviel/FramePack_F1_I2V_HY_20250503) | Trained with a novel anti-drifting strategy but inference is performed in "vanilla" strategy as described in the paper. Inference requires setting `sampling_type="vanilla"` when running the pipeline. |
## Usage
Refer to the pipeline documentation for basic usage examples. The following section contains examples of offloading, different sampling methods, quantization, and more.
### First and last frame to video
The following example shows how to use Framepack with start and end image controls, using the inverted anti-drifiting sampling model.
```python
import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
"lllyasviel/FramePackI2V_HY", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
torch_dtype=torch.float16,
)
# Enable memory optimizations
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
first_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png"
)
last_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png"
)
output = pipe(
image=first_image,
last_image=last_image,
prompt=prompt,
height=512,
width=512,
num_frames=91,
num_inference_steps=30,
guidance_scale=9.0,
generator=torch.Generator().manual_seed(0),
sampling_type="inverted_anti_drifting",
).frames[0]
export_to_video(output, "output.mp4", fps=30)
```
### Vanilla sampling
The following example shows how to use Framepack with the F1 model trained with vanilla sampling but new regulation approach for anti-drifting.
```python
import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
"lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
torch_dtype=torch.float16,
)
# Enable memory optimizations
pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
)
output = pipe(
image=image,
prompt="A penguin dancing in the snow",
height=832,
width=480,
num_frames=91,
num_inference_steps=30,
guidance_scale=9.0,
generator=torch.Generator().manual_seed(0),
sampling_type="vanilla",
).frames[0]
export_to_video(output, "output.mp4", fps=30)
```
### Group offloading
Group offloading ([`~hooks.apply_group_offloading`]) provides aggressive memory optimizations for offloading internal parts of any model to the CPU, with possibly no additional overhead to generation time. If you have very low VRAM available, this approach may be suitable for you depending on the amount of CPU RAM available.
```python
import torch
from diffusers import HunyuanVideoFramepackPipeline, HunyuanVideoFramepackTransformer3DModel
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import SiglipImageProcessor, SiglipVisionModel
transformer = HunyuanVideoFramepackTransformer3DModel.from_pretrained(
"lllyasviel/FramePack_F1_I2V_HY_20250503", torch_dtype=torch.bfloat16
)
feature_extractor = SiglipImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="feature_extractor"
)
image_encoder = SiglipVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl", subfolder="image_encoder", torch_dtype=torch.float16
)
pipe = HunyuanVideoFramepackPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
transformer=transformer,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
torch_dtype=torch.float16,
)
# Enable group offloading
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
list(map(
lambda x: apply_group_offloading(x, onload_device, offload_device, offload_type="leaf_level", use_stream=True, low_cpu_mem_usage=True),
[pipe.text_encoder, pipe.text_encoder_2, pipe.transformer]
))
pipe.image_encoder.to(onload_device)
pipe.vae.to(onload_device)
pipe.vae.enable_tiling()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/penguin.png"
)
output = pipe(
image=image,
prompt="A penguin dancing in the snow",
height=832,
width=480,
num_frames=91,
num_inference_steps=30,
guidance_scale=9.0,
generator=torch.Generator().manual_seed(0),
sampling_type="vanilla",
).frames[0]
print(f"Max memory: {torch.cuda.max_memory_allocated() / 1024**3:.3f} GB")
export_to_video(output, "output.mp4", fps=30)
```
## HunyuanVideoFramepackPipeline
[[autodoc]] HunyuanVideoFramepackPipeline
- all
- __call__
## HunyuanVideoPipelineOutput
[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
+100 -3
View File
@@ -31,12 +31,103 @@ Available models:
| Model name | Recommended dtype |
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 2B 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
| [`LTX Video 13B 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-13b-0.9.7-dev.safetensors) | `torch.bfloat16` |
| [`LTX Video Spatial Upscaler 0.9.7`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltxv-spatial-upscaler-0.9.7.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
## Recommended settings for generation
For the best results, it is recommended to follow the guidelines mentioned in the official LTX Video [repository](https://github.com/Lightricks/LTX-Video).
- Some variants of LTX Video are guidance-distilled. For guidance-distilled models, `guidance_scale` must be set to `1.0`. For any other models, `guidance_scale` should be set higher (e.g., `5.0`) for good generation quality.
- For variants with a timestep-aware VAE (LTXV 0.9.1 and above), it is recommended to set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- For variants that support interpolation between multiple conditioning images and videos (LTXV 0.9.5 and above), it is recommended to use similar looking images/videos for the best results. High divergence between the conditionings may lead to abrupt transitions in the generated video.
## Using LTX Video 13B 0.9.7
LTX Video 0.9.7 comes with a spatial latent upscaler and a 13B parameter transformer. The inference involves generating a low resolution video first, which is very fast, followed by upscaling and refining the generated video.
<!-- TODO(aryan): modify when official checkpoints are available -->
```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXVideoCondition
from diffusers.utils import export_to_video, load_video
pipe = LTXConditionPipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-diffusers", torch_dtype=torch.bfloat16)
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained("a-r-r-o-w/LTX-Video-0.9.7-Latent-Spatial-Upsampler-diffusers", vae=pipe.vae, torch_dtype=torch.bfloat16)
pipe.to("cuda")
pipe_upsample.to("cuda")
pipe.vae.enable_tiling()
def round_to_nearest_resolution_acceptable_by_vae(height, width):
height = height - (height % pipe.vae_temporal_compression_ratio)
width = width - (width % pipe.vae_temporal_compression_ratio)
return height, width
video = load_video(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
)[:21] # Use only the first 21 frames as conditioning
condition1 = LTXVideoCondition(video=video, frame_index=0)
prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
expected_height, expected_width = 768, 1152
downscale_factor = 2 / 3
num_frames = 161
# Part 1. Generate video at smaller resolution
# Text-only conditioning is also supported without the need to pass `conditions`
downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(expected_width * downscale_factor)
downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(downscaled_height, downscaled_width)
latents = pipe(
conditions=[condition1],
prompt=prompt,
negative_prompt=negative_prompt,
width=downscaled_width,
height=downscaled_height,
num_frames=num_frames,
num_inference_steps=30,
generator=torch.Generator().manual_seed(0),
output_type="latent",
).frames
# Part 2. Upscale generated video using latent upsampler with fewer inference steps
# The available latent upsampler upscales the height/width by 2x
upscaled_height, upscaled_width = downscaled_height * 2, downscaled_width * 2
upscaled_latents = pipe_upsample(
latents=latents,
output_type="latent"
).frames
# Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
video = pipe(
conditions=[condition1],
prompt=prompt,
negative_prompt=negative_prompt,
width=upscaled_width,
height=upscaled_height,
num_frames=num_frames,
denoise_strength=0.4, # Effectively, 4 inference steps out of 10
num_inference_steps=10,
latents=upscaled_latents,
decode_timestep=0.05,
image_cond_noise_scale=0.025,
generator=torch.Generator().manual_seed(0),
output_type="pil",
).frames[0]
# Part 4. Downscale the video to the expected resolution
video = [frame.resize((expected_width, expected_height)) for frame in video]
export_to_video(video, "output.mp4", fps=24)
```
## Loading Single Files
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`]. We recommend using `from_single_file` for the Lightricks series of models, as they plan to release multiple models in the future in the single file format.
@@ -204,6 +295,12 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__
## LTXLatentUpsamplePipeline
[[autodoc]] LTXLatentUpsamplePipeline
- all
- __call__
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
+1
View File
@@ -89,6 +89,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation |
| [Value-guided planning](value_guided_sampling) | value guided sampling |
| [Wuerstchen](wuerstchen) | text2image |
| [VisualCloze](visualcloze) | text2image, image2image, subject driven generation, inpainting, style transfer, image restoration, image editing, [depth,normal,edge,pose]2image, [depth,normal,edge,pose]-estimation, virtual try-on, image relighting |
## DiffusionPipeline
+300
View File
@@ -0,0 +1,300 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# VisualCloze
[VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning](https://arxiv.org/abs/2504.07960) is an innovative in-context learning based universal image generation framework that offers key capabilities:
1. Support for various in-domain tasks
2. Generalization to unseen tasks through in-context learning
3. Unify multiple tasks into one step and generate both target image and intermediate results
4. Support reverse-engineering conditions from target images
## Overview
The abstract from the paper is:
*Recent progress in diffusion models significantly advances various image generation tasks. However, the current mainstream approach remains focused on building task-specific models, which have limited efficiency when supporting a wide range of different needs. While universal models attempt to address this limitation, they face critical challenges, including generalizable task instruction, appropriate task distributions, and unified architectural design. To tackle these challenges, we propose VisualCloze, a universal image generation framework, which supports a wide range of in-domain tasks, generalization to unseen ones, unseen unification of multiple tasks, and reverse generation. Unlike existing methods that rely on language-based task instruction, leading to task ambiguity and weak generalization, we integrate visual in-context learning, allowing models to identify tasks from visual demonstrations. Meanwhile, the inherent sparsity of visual task distributions hampers the learning of transferable knowledge across tasks. To this end, we introduce Graph200K, a graph-structured dataset that establishes various interrelated tasks, enhancing task density and transferable knowledge. Furthermore, we uncover that our unified image generation formulation shared a consistent objective with image infilling, enabling us to leverage the strong generative priors of pre-trained infilling models without modifying the architectures. The codes, dataset, and models are available at https://visualcloze.github.io.*
## Inference
### Model loading
VisualCloze is a two-stage cascade pipeline, containing `VisualClozeGenerationPipeline` and `VisualClozeUpsamplingPipeline`.
- In `VisualClozeGenerationPipeline`, each image is downsampled before concatenating images into a grid layout, avoiding excessively high resolutions. VisualCloze releases two models suitable for diffusers, i.e., [VisualClozePipeline-384](https://huggingface.co/VisualCloze/VisualClozePipeline-384) and [VisualClozePipeline-512](https://huggingface.co/VisualCloze/VisualClozePipeline-384), which downsample images to resolutions of 384 and 512, respectively.
- `VisualClozeUpsamplingPipeline` uses [SDEdit](https://arxiv.org/abs/2108.01073) to enable high-resolution image synthesis.
The `VisualClozePipeline` integrates both stages to support convenient end-to-end sampling, while also allowing users to utilize each pipeline independently as needed.
### Input Specifications
#### Task and Content Prompts
- Task prompt: Required to describe the generation task intention
- Content prompt: Optional description or caption of the target image
- When content prompt is not needed, pass `None`
- For batch inference, pass `List[str|None]`
#### Image Input Format
- Format: `List[List[Image|None]]`
- Structure:
- All rows except the last represent in-context examples
- Last row represents the current query (target image set to `None`)
- For batch inference, pass `List[List[List[Image|None]]]`
#### Resolution Control
- Default behavior:
- Initial generation in the first stage: area of ${pipe.resolution}^2$
- Upsampling in the second stage: 3x factor
- Custom resolution: Adjust using `upsampling_height` and `upsampling_width` parameters
### Examples
For comprehensive examples covering a wide range of tasks, please refer to the [Online Demo](https://huggingface.co/spaces/VisualCloze/VisualCloze) and [GitHub Repository](https://github.com/lzyhha/VisualCloze). Below are simple examples for three cases: mask-to-image conversion, edge detection, and subject-driven generation.
#### Example for mask2image
```python
import torch
from diffusers import VisualClozePipeline
from diffusers.utils import load_image
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load in-context images (make sure the paths are correct and accessible)
image_paths = [
# in-context examples
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg'),
],
# query with the target image
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg'),
None, # No image needed for the target image
],
]
# Task and content prompt
task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
content_prompt = """Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape.
The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible.
Its plumage is a mix of dark brown and golden hues, with intricate feather details.
The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere.
The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field,
soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background,
tranquil, majestic, wildlife photography."""
# Run the pipeline
image_result = pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image_paths,
upsampling_width=1344,
upsampling_height=768,
upsampling_strength=0.4,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0][0]
# Save the resulting image
image_result.save("visualcloze.png")
```
#### Example for edge-detection
```python
import torch
from diffusers import VisualClozePipeline
from diffusers.utils import load_image
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load in-context images (make sure the paths are correct and accessible)
image_paths = [
# in-context examples
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_image.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-1_edge.jpg'),
],
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_image.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_incontext-example-2_edge.jpg'),
],
# query with the target image
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_edgedetection_query_image.jpg'),
None, # No image needed for the target image
],
]
# Task and content prompt
task_prompt = "Each row illustrates a pathway from [IMAGE1] a sharp and beautifully composed photograph to [IMAGE2] edge map with natural well-connected outlines using a clear logical task."
content_prompt = ""
# Run the pipeline
image_result = pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image_paths,
upsampling_width=864,
upsampling_height=1152,
upsampling_strength=0.4,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0][0]
# Save the resulting image
image_result.save("visualcloze.png")
```
#### Example for subject-driven generation
```python
import torch
from diffusers import VisualClozePipeline
from diffusers.utils import load_image
pipe = VisualClozePipeline.from_pretrained("VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load in-context images (make sure the paths are correct and accessible)
image_paths = [
# in-context examples
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_reference.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_depth.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-1_image.jpg'),
],
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_reference.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_depth.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_incontext-example-2_image.jpg'),
],
# query with the target image
[
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_reference.jpg'),
load_image('https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_subjectdriven_query_depth.jpg'),
None, # No image needed for the target image
],
]
# Task and content prompt
task_prompt = """Each row describes a process that begins with [IMAGE1] an image containing the key object,
[IMAGE2] depth map revealing gray-toned spatial layers and results in
[IMAGE3] an image with artistic qualitya high-quality image with exceptional detail."""
content_prompt = """A vintage porcelain collector's item. Beneath a blossoming cherry tree in early spring,
this treasure is photographed up close, with soft pink petals drifting through the air and vibrant blossoms framing the scene."""
# Run the pipeline
image_result = pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image_paths,
upsampling_width=1024,
upsampling_height=1024,
upsampling_strength=0.2,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0)
).images[0][0]
# Save the resulting image
image_result.save("visualcloze.png")
```
#### Utilize each pipeline independently
```python
import torch
from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
from diffusers.utils import load_image
from PIL import Image
pipe = VisualClozeGenerationPipeline.from_pretrained(
"VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
image_paths = [
# in-context examples
[
load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
),
load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
),
],
# query with the target image
[
load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
),
None, # No image needed for the target image
],
]
task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
# Stage 1: Generate initial image
image = pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image_paths,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0),
).images[0][0]
# Stage 2 (optional): Upsample the generated image
pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
pipe_upsample.to("cuda")
mask_image = Image.new("RGB", image.size, (255, 255, 255))
image = pipe_upsample(
image=image,
mask_image=mask_image,
prompt=content_prompt,
width=1344,
height=768,
strength=0.4,
guidance_scale=30,
num_inference_steps=30,
max_sequence_length=512,
generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("visualcloze.png")
```
## VisualClozePipeline
[[autodoc]] VisualClozePipeline
- all
- __call__
## VisualClozeGenerationPipeline
[[autodoc]] VisualClozeGenerationPipeline
- all
- __call__
+4 -3
View File
@@ -13,9 +13,7 @@ specific language governing permissions and limitations under the License.
# Quantization
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index).
Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class.
Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference.
<Tip>
@@ -23,6 +21,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
</Tip>
## PipelineQuantizationConfig
[[autodoc]] quantizers.PipelineQuantizationConfig
## BitsAndBytesConfig
+17
View File
@@ -78,6 +78,23 @@ For more information and different options about `torch.compile`, refer to the [
> [!TIP]
> Learn more about other ways PyTorch 2.0 can help optimize your model in the [Accelerate inference of text-to-image diffusion models](../tutorials/fast_diffusion) tutorial.
### Regional compilation
Compiling the whole model usually has a big problem space for optimization. Models are often composed of multiple repeated blocks. [Regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) compiles the repeated block first (a transformer encoder block, for example), so that the Torch compiler would re-use its cached/optimized generated code for the other blocks, reducing (often massively) the cold start compilation time observed on the first inference call.
Enabling regional compilation might require simple yet intrusive changes to the
modeling code. However, 🤗 Accelerate provides a utility [`compile_regions()`](https://huggingface.co/docs/accelerate/main/en/usage_guides/compilation#how-to-use-regional-compilation) which automatically compiles
the repeated blocks of the provided `nn.Module` sequentially, and the rest of the model separately. This helps with reducing cold start time while keeping most (if not all) of the speedup you would get from full compilation.
```py
# Make sure you're on the latest `accelerate`: `pip install -U accelerate`.
from accelerate.utils import compile_regions
pipe.unet = compile_regions(pipe.unet, mode="reduce-overhead", fullgraph=True)
```
As you may have noticed `compile_regions()` takes the same arguments as `torch.compile()`, allowing flexibility.
## Benchmark
We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. The code is benchmarked on 🤗 Diffusers v0.17.0.dev0 to optimize `torch.compile` usage (see [here](https://github.com/huggingface/diffusers/pull/3313) for more details).
+9 -3
View File
@@ -48,7 +48,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
```py
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
import torch
from diffusers import AutoModel
from transformers import T5EncoderModel
@@ -88,6 +88,8 @@ Setting `device_map="auto"` automatically fills all available space on the GPU(s
CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
```py
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer_8bit,
@@ -132,7 +134,7 @@ For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bf
```py
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
import torch
from diffusers import AutoModel
from transformers import T5EncoderModel
@@ -171,6 +173,8 @@ Let's generate an image using our quantized models.
Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
```py
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer_4bit,
@@ -214,6 +218,8 @@ Check your memory footprint with the `get_memory_footprint` method:
print(model.get_memory_footprint())
```
Note that this only tells you the memory footprint of the model params and does _not_ estimate the inference memory requirements.
Quantized models can be loaded from the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameters:
```py
@@ -413,4 +419,4 @@ transformer_4bit.dequantize()
## Resources
* [End-to-end notebook showing Flux.1 Dev inference in a free-tier Colab](https://gist.github.com/sayakpaul/c76bd845b48759e11687ac550b99d8b4)
* [Training](https://gist.github.com/sayakpaul/05afd428bc089b47af7c016e42004527)
* [Training](https://github.com/huggingface/diffusers/blob/8c661ea586bf11cb2440da740dd3c4cf84679b85/examples/dreambooth/README_hidream.md#using-quantization)
+87
View File
@@ -39,3 +39,90 @@ Diffusers currently supports the following quantization methods.
- [Quanto](./quanto.md)
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
## Pipeline-level quantization
Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply
quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can
do this with [`~quantizers.PipelineQuantizationConfig`].
Start by defining a `PipelineQuantizationConfig`:
```py
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers.quantization_config import QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from transformers import BitsAndBytesConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={
"transformer": QuantoConfig(weights_dtype="int8"),
"text_encoder_2": BitsAndBytesConfig(
load_in_4bit=True, compute_dtype=torch.bfloat16
),
}
)
```
Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference:
```py
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
).to("cuda")
image = pipe("photo of a cute dog").images[0]
```
This method allows for more granular control over the quantization specifications of individual
model-level components of a pipeline. It also allows for different quantization backends for
different components. In the above example, you used a combination of Quanto and BitsandBytes. However,
one caveat of this method is that users need to know which components come from `transformers` to be able
to import the right quantization config class.
The other method is simpler in terms of experience but is
less-flexible. Start by defining a `PipelineQuantizationConfig` but in a different way:
```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer", "text_encoder_2"],
)
```
This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example.
In this case, `quant_kwargs` will be used to initialize the quantization specifications
of the respective quantization configuration class of `quant_backend`. `components_to_quantize`
is used to denote the components that will be quantized. For most pipelines, you would want to
keep `transformer` in the list as that is often the most compute and memory intensive.
The config below will work for most diffusion pipelines that have a `transformer` component present.
In most case, you will want to quantize the `transformer` component as that is often the most compute-
intensive part of a diffusion pipeline.
```py
pipeline_quant_config = PipelineQuantizationConfig(
quant_backend="bitsandbytes_4bit",
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
components_to_quantize=["transformer"],
)
```
Below is a list of the supported quantization backends available in both `diffusers` and `transformers`:
* `bitsandbytes_4bit`
* `bitsandbytes_8bit`
* `gguf`
* `quanto`
* `torchao`
Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's
recommended to quantize the text encoders that are memory-intensive. Some examples include T5,
Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through
`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`).
@@ -430,6 +430,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1554,6 +1557,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -1562,6 +1566,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
@@ -658,6 +658,8 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--use_dora",
action="store_true",
@@ -1248,6 +1250,7 @@ def main(args):
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
use_dora=args.use_dora,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
@@ -1260,6 +1263,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
use_dora=args.use_dora,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -767,6 +767,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--use_dora",
action="store_true",
@@ -1558,6 +1561,7 @@ def main(args):
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -1570,6 +1574,7 @@ def main(args):
r=args.rank,
use_dora=args.use_dora,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
+18 -7
View File
@@ -1114,17 +1114,22 @@ def main(args):
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -1156,8 +1161,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -524,6 +524,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--image_interpolation_mode",
type=str,
@@ -932,6 +935,7 @@ def main(args):
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
@@ -942,6 +946,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
@@ -358,6 +358,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1236,6 +1239,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -1244,6 +1248,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
@@ -417,6 +417,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1161,6 +1164,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -328,6 +328,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1023,6 +1026,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -323,6 +323,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1021,6 +1024,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -367,6 +367,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--with_prior_preservation",
default=False,
@@ -1264,6 +1267,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=target_modules,
)
@@ -1273,6 +1277,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
lora_dropout=args.lora_dropout,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
@@ -659,6 +659,9 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
parser.add_argument(
"--use_dora",
action="store_true",
@@ -1199,10 +1202,11 @@ def main(args):
text_encoder_one.gradient_checkpointing_enable()
text_encoder_two.gradient_checkpointing_enable()
def get_lora_config(rank, use_dora, target_modules):
def get_lora_config(rank, dropout, use_dora, target_modules):
base_config = {
"r": rank,
"lora_alpha": rank,
"lora_dropout": dropout,
"init_lora_weights": "gaussian",
"target_modules": target_modules,
}
@@ -1218,14 +1222,24 @@ def main(args):
# now we will add new LoRA weights to the attention layers
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
unet_lora_config = get_lora_config(
rank=args.rank,
dropout=args.lora_dropout,
use_dora=args.use_dora,
target_modules=unet_target_modules,
)
unet.add_adapter(unet_lora_config)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
text_lora_config = get_lora_config(
rank=args.rank,
dropout=args.lora_dropout,
use_dora=args.use_dora,
target_modules=text_target_modules,
)
text_encoder_one.add_adapter(text_lora_config)
text_encoder_two.add_adapter(text_lora_config)
@@ -812,7 +812,7 @@ def main(args):
if args.scale_lr:
args.learning_rate = (
args.learning_rat * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes
args.learning_rate * args.gradient_accumulation_steps * args.per_gpu_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+95
View File
@@ -0,0 +1,95 @@
# Training SANA Sprint Diffuser
This README explains how to use the provided bash script commands to download a pre-trained teacher diffuser model and train it on a specific dataset, following the [SANA Sprint methodology](https://arxiv.org/abs/2503.09641).
## Setup
### 1. Define the local paths
Set a variable for your desired output directory. This directory will store the downloaded model and the training checkpoints/results.
```bash
your_local_path='output' # Or any other path you prefer
mkdir -p $your_local_path # Create the directory if it doesn't exist
```
### 2. Download the pre-trained model
Download the SANA Sprint teacher model from Hugging Face Hub. The script uses the 1.6B parameter model.
```bash
huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
```
*(Optional: You can also download the 0.6B model by replacing the model name: `Efficient-Large-Model/Sana_Sprint_0.6B_1024px_teacher_diffusers`)*
### 3. Acquire the dataset shards
The training script in this example uses specific `.parquet` shards from a randomly selected `brivangl/midjourney-v6-llava` dataset instead of downloading the entire dataset automatically via `dataset_name`.
The script specifically uses these three files:
* `data/train_000.parquet`
* `data/train_001.parquet`
* `data/train_002.parquet`
You can either:
Let the script download the dataset automatically during first run
Or download it manually
**Note:** The full `brivangl/midjourney-v6-llava` dataset is much larger and contains many more shards. This script example explicitly trains *only* on the three specified shards.
## Usage
Once the model is downloaded, you can run the training script.
```bash
your_local_path='output' # Ensure this variable is set
python train_sana_sprint_diffusers.py \
--pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
--output_dir=$your_local_path \
--mixed_precision=bf16 \
--resolution=1024 \
--learning_rate=1e-6 \
--max_train_steps=30000 \
--dataloader_num_workers=8 \
--dataset_name='brivangl/midjourney-v6-llava' \
--file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
--checkpointing_steps=500 --checkpoints_total_limit=10 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--seed=453645634 \
--train_largest_timestep \
--misaligned_pairs_D \
--gradient_checkpointing \
--resume_from_checkpoint="latest" \
```
### Explanation of parameters
* `--pretrained_model_name_or_path`: Path to the downloaded pre-trained model directory.
* `--output_dir`: Directory where training logs, checkpoints, and the final model will be saved.
* `--mixed_precision`: Use BF16 mixed precision for training, which can save memory and speed up training on compatible hardware.
* `--resolution`: The image resolution used for training (1024x1024).
* `--learning_rate`: The learning rate for the optimizer.
* `--max_train_steps`: The total number of training steps to perform.
* `--dataloader_num_workers`: Number of worker processes for loading data. Increase for faster data loading if your CPU and disk can handle it.
* `--dataset_name`: The name of the dataset on Hugging Face Hub (`brivangl/midjourney-v6-llava`).
* `--file_path`: **Specifies the local paths to the dataset shards to be used for training.** In this case, `data/train_000.parquet`, `data/train_001.parquet`, and `data/train_002.parquet`.
* `--checkpointing_steps`: Save a training checkpoint every X steps.
* `--checkpoints_total_limit`: Maximum number of checkpoints to keep. Older checkpoints will be deleted.
* `--train_batch_size`: The batch size per GPU.
* `--gradient_accumulation_steps`: Number of steps to accumulate gradients before performing an optimizer step.
* `--seed`: Random seed for reproducibility.
* `--train_largest_timestep`: A specific training strategy focusing on larger timesteps.
* `--misaligned_pairs_D`: Another specific training strategy to add misaligned image-text pairs as fake data for GAN.
* `--gradient_checkpointing`: Enable gradient checkpointing to save GPU memory.
* `--resume_from_checkpoint`: Allows resuming training from the latest saved checkpoint in the `--output_dir`.
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,26 @@
your_local_path='output'
huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
# or Sana_Sprint_0.6B_1024px_teacher_diffusers
python train_sana_sprint_diffusers.py \
--pretrained_model_name_or_path=$your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers \
--output_dir=$your_local_path \
--mixed_precision=bf16 \
--resolution=1024 \
--learning_rate=1e-6 \
--max_train_steps=30000 \
--dataloader_num_workers=8 \
--dataset_name='brivangl/midjourney-v6-llava' \
--file_path data/train_000.parquet data/train_001.parquet data/train_002.parquet \
--checkpointing_steps=500 --checkpoints_total_limit=10 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--seed=453645634 \
--train_largest_timestep \
--misaligned_pairs_D \
--gradient_checkpointing \
--resume_from_checkpoint="latest" \
+352
View File
@@ -0,0 +1,352 @@
import argparse
import pathlib
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import AutoencoderKLCosmos, CosmosTextToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
def remove_keys_(key: str, state_dict: Dict[str, Any]):
state_dict.pop(key)
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
state_dict[new_key] = state_dict.pop(old_key)
def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
block_index = int(key.split(".")[1].removeprefix("block"))
new_key = key
old_prefix = f"blocks.block{block_index}"
new_prefix = f"transformer_blocks.{block_index}"
new_key = new_prefix + new_key.removeprefix(old_prefix)
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT = {
"t_embedder.1": "time_embed.t_embedder",
"affline_norm": "time_embed.norm",
".blocks.0.block.attn": ".attn1",
".blocks.1.block.attn": ".attn2",
".blocks.2.block": ".ff",
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
"to_q.0": "to_q",
"to_q.1": "norm_q",
"to_k.0": "to_k",
"to_k.1": "norm_k",
"to_v.0": "to_v",
"layer1": "net.0.proj",
"layer2": "net.2",
"proj.1": "proj",
"x_embedder": "patch_embed",
"extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {
"blocks.block": rename_transformer_blocks_,
"logvar.0.freqs": remove_keys_,
"logvar.0.phases": remove_keys_,
"logvar.1.weight": remove_keys_,
"pos_embedder.seq": remove_keys_,
}
TRANSFORMER_CONFIGS = {
"Cosmos-1.0-Diffusion-7B-Text2World": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 32,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (2.0, 1.0, 1.0),
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
"Cosmos-1.0-Diffusion-7B-Video2World": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 32,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (2.0, 1.0, 1.0),
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
"Cosmos-1.0-Diffusion-14B-Text2World": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (2.0, 2.0, 2.0),
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
"Cosmos-1.0-Diffusion-14B-Video2World": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (2.0, 2.0, 2.0),
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
},
}
VAE_KEYS_RENAME_DICT = {
"down.0": "down_blocks.0",
"down.1": "down_blocks.1",
"down.2": "down_blocks.2",
"up.0": "up_blocks.2",
"up.1": "up_blocks.1",
"up.2": "up_blocks.0",
".block.": ".resnets.",
"downsample": "downsamplers.0",
"upsample": "upsamplers.0",
"mid.block_1": "mid_block.resnets.0",
"mid.attn_1.0": "mid_block.attentions.0",
"mid.attn_1.1": "mid_block.temp_attentions.0",
"mid.block_2": "mid_block.resnets.1",
".q.conv3d": ".to_q",
".k.conv3d": ".to_k",
".v.conv3d": ".to_v",
".proj_out.conv3d": ".to_out.0",
".0.conv3d": ".conv_s",
".1.conv3d": ".conv_t",
"conv1.conv3d": "conv1",
"conv2.conv3d": "conv2",
"conv3.conv3d": "conv3",
"nin_shortcut.conv3d": "conv_shortcut",
"quant_conv.conv3d": "quant_conv",
"post_quant_conv.conv3d": "post_quant_conv",
}
VAE_SPECIAL_KEYS_REMAP = {
"wavelets": remove_keys_,
"_arange": remove_keys_,
"patch_size_buffer": remove_keys_,
}
VAE_CONFIGS = {
"CV8x8x8-0.1": {
"name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 16,
"encoder_block_out_channels": (128, 256, 512, 512),
"decode_block_out_channels": (256, 512, 512, 512),
"attention_resolutions": (32,),
"resolution": 1024,
"num_layers": 2,
"patch_size": 4,
"patch_type": "haar",
"scaling_factor": 1.0,
"spatial_compression_ratio": 8,
"temporal_compression_ratio": 8,
"latents_mean": None,
"latents_std": None,
},
},
"CV8x8x8-1.0": {
"name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 16,
"encoder_block_out_channels": (128, 256, 512, 512),
"decode_block_out_channels": (256, 512, 512, 512),
"attention_resolutions": (32,),
"resolution": 1024,
"num_layers": 2,
"patch_size": 4,
"patch_type": "haar",
"scaling_factor": 1.0,
"spatial_compression_ratio": 8,
"temporal_compression_ratio": 8,
"latents_mean": None,
"latents_std": None,
},
},
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
if "model" in saved_dict.keys():
state_dict = state_dict["model"]
if "module" in saved_dict.keys():
state_dict = state_dict["module"]
if "state_dict" in saved_dict.keys():
state_dict = state_dict["state_dict"]
return state_dict
def convert_transformer(transformer_type: str, ckpt_path: str):
PREFIX_KEY = "net."
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
with init_empty_weights():
config = TRANSFORMER_CONFIGS[transformer_type]
transformer = CosmosTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
def convert_vae(vae_type: str):
model_name = VAE_CONFIGS[vae_type]["name"]
snapshot_directory = snapshot_download(model_name, repo_type="model")
directory = pathlib.Path(snapshot_directory)
autoencoder_file = directory / "autoencoder.jit"
mean_std_file = directory / "mean_std.pt"
original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
if mean_std_file.exists():
mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
else:
mean_std = (None, None)
config = VAE_CONFIGS[vae_type]["diffusers_config"]
config.update(
{
"latents_mean": mean_std[0].detach().cpu().numpy().tolist(),
"latents_std": mean_std[1].detach().cpu().numpy().tolist(),
}
)
vae = AutoencoderKLCosmos(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
if __name__ == "__main__":
args = get_args()
transformer = None
dtype = DTYPE_MAPPING[args.dtype]
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
if args.transformer_ckpt_path is not None:
transformer = convert_transformer(args.transformer_type, args.transformer_ckpt_path)
transformer = transformer.to(dtype=dtype)
if not args.save_pipeline:
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.vae_type is not None:
vae = convert_vae(args.vae_type)
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
# The original code initializes EDM config with sigma_min=0.0002, but does not make use of it anywhere directly.
# So, the sigma_min values that is used is the default value of 0.002.
scheduler = EDMEulerScheduler(
sigma_min=0.002,
sigma_max=80,
sigma_data=0.5,
sigma_schedule="karras",
num_train_timesteps=1000,
prediction_type="epsilon",
rho=7.0,
final_sigmas_type="sigma_min",
)
pipe = CosmosTextToWorldPipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+172 -27
View File
@@ -7,7 +7,15 @@ from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
from diffusers import (
AutoencoderKLLTXVideo,
FlowMatchEulerDiscreteScheduler,
LTXConditionPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.pipelines.ltx.modeling_latent_upsampler import LTXLatentUpsamplerModel
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -123,17 +131,10 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
state_dict[new_key] = state_dict.pop(old_key)
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
version: str = "0.9.0",
):
def convert_transformer(ckpt_path: str, config, dtype: torch.dtype):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
config = {}
if version == "0.9.5":
config["_use_causal_rope_fix"] = True
with init_empty_weights():
transformer = LTXVideoTransformer3DModel(**config)
@@ -180,8 +181,59 @@ def convert_vae(ckpt_path: str, config, dtype: torch.dtype):
return vae
def convert_spatial_latent_upsampler(ckpt_path: str, config, dtype: torch.dtype):
original_state_dict = get_state_dict(load_file(ckpt_path))
with init_empty_weights():
latent_upsampler = LTXLatentUpsamplerModel(**config)
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
return latent_upsampler
def get_transformer_config(version: str) -> Dict[str, Any]:
if version == "0.9.7":
config = {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 4096,
"attention_bias": True,
"attention_out_bias": True,
}
else:
config = {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 64,
"cross_attention_dim": 2048,
"num_layers": 28,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 4096,
"attention_bias": True,
"attention_out_bias": True,
}
return config
def get_vae_config(version: str) -> Dict[str, Any]:
if version == "0.9.0":
if version in ["0.9.0"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -210,7 +262,7 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
"timestep_conditioning": False,
}
elif version == "0.9.1":
elif version in ["0.9.1"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -240,7 +292,39 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
elif version == "0.9.5":
elif version in ["0.9.5"]:
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 1024, 2048),
"down_block_types": (
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
elif version in ["0.9.7"]:
config = {
"in_channels": 3,
"out_channels": 3,
@@ -275,12 +359,33 @@ def get_vae_config(version: str) -> Dict[str, Any]:
return config
def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
if version == "0.9.7":
config = {
"in_channels": 128,
"mid_channels": 512,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
parser.add_argument(
"--spatial_latent_upsampler_path",
type=str,
default=None,
help="Path to original spatial latent upsampler checkpoint",
)
parser.add_argument(
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
)
@@ -294,7 +399,11 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
"--version",
type=str,
default="0.9.0",
choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
help="Version of the LTX model",
)
return parser.parse_args()
@@ -320,11 +429,9 @@ if __name__ == "__main__":
variant = VARIANT_MAPPING[args.dtype]
output_path = Path(args.output_path)
if args.save_pipeline:
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
if args.transformer_ckpt_path is not None:
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
config = get_transformer_config(args.version)
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, config, dtype)
if not args.save_pipeline:
transformer.save_pretrained(
output_path / "transformer", safe_serialization=True, max_shard_size="5GB", variant=variant
@@ -336,6 +443,16 @@ if __name__ == "__main__":
if not args.save_pipeline:
vae.save_pretrained(output_path / "vae", safe_serialization=True, max_shard_size="5GB", variant=variant)
if args.spatial_latent_upsampler_path is not None:
config = get_spatial_latent_upsampler_config(args.version)
latent_upsampler: LTXLatentUpsamplerModel = convert_spatial_latent_upsampler(
args.spatial_latent_upsampler_path, config, dtype
)
if not args.save_pipeline:
latent_upsampler.save_pretrained(
output_path / "latent_upsampler", safe_serialization=True, max_shard_size="5GB", variant=variant
)
if args.save_pipeline:
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
@@ -348,7 +465,7 @@ if __name__ == "__main__":
for param in text_encoder.parameters():
param.data = param.data.contiguous()
if args.version == "0.9.5":
if args.version in ["0.9.5", "0.9.7"]:
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
@@ -360,12 +477,40 @@ if __name__ == "__main__":
shift_terminal=0.1,
)
pipe = LTXPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
if args.version in ["0.9.0", "0.9.1", "0.9.5"]:
pipe = LTXPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe.save_pretrained(
output_path.as_posix(), safe_serialization=True, variant=variant, max_shard_size="5GB"
)
elif args.version in ["0.9.7"]:
pipe = LTXConditionPipeline(
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe_upsample = LTXLatentUpsamplePipeline(
vae=vae,
latent_upsampler=latent_upsampler,
)
pipe.save_pretrained(
(output_path / "ltx_pipeline").as_posix(),
safe_serialization=True,
variant=variant,
max_shard_size="5GB",
)
pipe_upsample.save_pretrained(
(output_path / "ltx_upsample_pipeline").as_posix(),
safe_serialization=True,
variant=variant,
max_shard_size="5GB",
)
else:
raise ValueError(f"Unsupported version: {args.version}")
+21
View File
@@ -148,6 +148,7 @@ else:
"AutoencoderKL",
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
@@ -158,6 +159,7 @@ else:
"AutoencoderTiny",
"AutoModel",
"CacheMixin",
"ChromaTransformer2DModel",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"CogView4Transformer2DModel",
@@ -166,6 +168,7 @@ else:
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
"FluxControlNetModel",
@@ -175,6 +178,7 @@ else:
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
@@ -356,6 +360,9 @@ else:
"CogView3PlusPipeline",
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"CosmosTextToWorldPipeline",
"CosmosVideoToWorldPipeline",
"CycleDiffusionPipeline",
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
@@ -376,6 +383,7 @@ else:
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoPipeline",
"I2VGenXLPipeline",
@@ -413,6 +421,7 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
@@ -513,6 +522,8 @@ else:
"VersatileDiffusionPipeline",
"VersatileDiffusionTextToImagePipeline",
"VideoToVideoSDPipeline",
"VisualClozeGenerationPipeline",
"VisualClozePipeline",
"VQDiffusionPipeline",
"WanImageToVideoPipeline",
"WanPipeline",
@@ -743,6 +754,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -761,6 +773,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxControlNetModel,
@@ -770,6 +783,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
@@ -930,6 +944,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView3PlusPipeline,
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
CosmosTextToWorldPipeline,
CosmosVideoToWorldPipeline,
CycleDiffusionPipeline,
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
@@ -950,6 +967,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
I2VGenXLPipeline,
@@ -987,6 +1005,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
@@ -1086,6 +1105,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionPipeline,
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VisualClozeGenerationPipeline,
VisualClozePipeline,
VQDiffusionPipeline,
WanImageToVideoPipeline,
WanPipeline,
+2 -2
View File
@@ -348,7 +348,7 @@ def _load_lora_into_text_encoder(
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")
@@ -374,7 +374,7 @@ def _load_lora_into_text_encoder(
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+26 -1
View File
@@ -727,8 +727,25 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
elif k.startswith("lora_te1_"):
has_te_keys = True
continue
elif k.startswith("lora_transformer_context_embedder"):
diffusers_key = "context_embedder"
elif k.startswith("lora_transformer_norm_out_linear"):
diffusers_key = "norm_out.linear"
elif k.startswith("lora_transformer_proj_out"):
diffusers_key = "proj_out"
elif k.startswith("lora_transformer_x_embedder"):
diffusers_key = "x_embedder"
elif k.startswith("lora_transformer_time_text_embed_guidance_embedder_linear_"):
i = int(k.split("lora_transformer_time_text_embed_guidance_embedder_linear_")[-1])
diffusers_key = f"time_text_embed.guidance_embedder.linear_{i}"
elif k.startswith("lora_transformer_time_text_embed_text_embedder_linear_"):
i = int(k.split("lora_transformer_time_text_embed_text_embedder_linear_")[-1])
diffusers_key = f"time_text_embed.text_embedder.linear_{i}"
elif k.startswith("lora_transformer_time_text_embed_timestep_embedder_linear_"):
i = int(k.split("lora_transformer_time_text_embed_timestep_embedder_linear_")[-1])
diffusers_key = f"time_text_embed.timestep_embedder.linear_{i}"
else:
raise NotImplementedError
raise NotImplementedError(f"Handling for key ({k}) is not implemented.")
if "attn_" in k:
if "_to_out_0" in k:
@@ -1687,3 +1704,11 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
raise ValueError("Invalid LoRA state dict for HiDream.")
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
+7 -3
View File
@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_wan_lora_to_diffusers,
@@ -2103,7 +2104,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# Find invalid keys
transformer_state_dict = transformer.state_dict()
@@ -2425,7 +2426,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix = prefix or cls.transformer_name
for key in list(state_dict.keys()):
if key.split(".")[0] == prefix:
state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
state_dict[key.removeprefix(f"{prefix}.")] = state_dict.pop(key)
# Expand transformer parameter shapes if they don't match lora
has_param_with_shape_update = False
@@ -5371,7 +5372,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5465,6 +5465,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+5 -2
View File
@@ -57,6 +57,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
}
@@ -230,7 +231,7 @@ class PeftAdapterMixin:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -261,7 +262,9 @@ class PeftAdapterMixin:
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
network_alphas = {
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
+12 -3
View File
@@ -29,8 +29,10 @@ from .single_file_utils import (
convert_animatediff_checkpoint_to_diffusers,
convert_auraflow_transformer_checkpoint_to_diffusers,
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_transformer_to_diffusers,
convert_controlnet_checkpoint,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
@@ -133,6 +135,14 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
},
"HiDreamImageTransformer2DModel": {
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"ChromaTransformer2DModel": {
"checkpoint_mapping_fn": convert_chroma_transformer_to_diffusers,
"default_subfolder": "transformer",
},
}
@@ -187,9 +197,8 @@ class FromOriginalModelMixin:
original_config (`str`, *optional*):
Dict or path to a yaml file containing the configuration for the model in its original format.
If a dict is provided, it will be used to initialize the model configuration.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
+163 -1
View File
@@ -126,6 +126,7 @@ CHECKPOINT_KEY_NAMES = {
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -190,6 +191,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
}
# Use to configure model sample size when original config is provided
@@ -701,6 +703,8 @@ def infer_diffusers_model_type(checkpoint):
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
model_type = "hidream"
else:
model_type = "v1"
@@ -2195,7 +2199,6 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
# norms.
## norm1
converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mod.lin.weight"
)
@@ -2281,6 +2284,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.weight"] = checkpoint.pop(
f"single_blocks.{i}.modulation.lin.weight"
@@ -2316,6 +2320,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
checkpoint.pop("final_layer.adaLN_modulation.1.weight")
)
@@ -3293,3 +3298,160 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
converted_state_dict[key] = value
return converted_state_dict
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
return checkpoint
def convert_chroma_transformer_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
for k in keys:
if k.startswith("distilled_guidance_layer.norms"):
converted_state_dict[k.replace(".scale", ".weight")] = checkpoint.pop(k)
elif k.startswith("distilled_guidance_layer.layer"):
converted_state_dict[k.replace("in_layer", "linear_1").replace("out_layer", "linear_2")] = checkpoint.pop(
k
)
elif k.startswith("distilled_guidance_layer"):
converted_state_dict[k] = checkpoint.pop(k)
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
mlp_ratio = 4.0
inner_dim = 3072
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
def swap_scale_shift(weight):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
# context_embedder
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
# x_embedder
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
# norms.
# Q, K, V
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
context_q, context_k, context_v = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
)
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
)
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
# qk_norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)
# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.0.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.weight"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_mlp.2.bias"
)
# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
f"double_blocks.{i}.img_attn.proj.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.weight"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
f"double_blocks.{i}.txt_attn.proj.bias"
)
# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
q_bias, k_bias, v_bias, mlp_bias = torch.split(
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
)
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
# qk norm
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.query_norm.scale"
)
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
f"single_blocks.{i}.norm.key_norm.scale"
)
# output projections.
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
return converted_state_dict
+8
View File
@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
@@ -73,12 +74,15 @@ if is_torch_available():
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
@@ -113,6 +117,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKL,
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -146,16 +151,19 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
CosmosTransformer3DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
+2 -3
View File
@@ -161,9 +161,8 @@ class MultiAdapter(ModelMixin):
pretrained_model_path (`os.PathLike`):
A path to a *directory* containing model weights saved using
[`~diffusers.models.adapter.MultiAdapter.save_pretrained`], e.g., `./my_model_directory/adapter`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+2 -2
View File
@@ -203,8 +203,8 @@ class Attention(nn.Module):
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
elif qk_norm == "rms_norm":
self.norm_q = RMSNorm(dim_head, eps=eps)
self.norm_k = RMSNorm(dim_head, eps=eps)
self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "rms_norm_across_heads":
# LTX applies qk norm across all heads
self.norm_q = RMSNorm(dim_head * heads, eps=eps)
+2 -3
View File
@@ -52,9 +52,8 @@ class AutoModel(ConfigMixin):
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -3,6 +3,7 @@ from .autoencoder_dc import AutoencoderDC
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
File diff suppressed because it is too large Load Diff
+11
View File
@@ -744,6 +744,17 @@ class DiagonalGaussianDistribution(object):
return self.mean
class IdentityDistribution(object):
def __init__(self, parameters: torch.Tensor):
self.parameters = parameters
def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor:
return self.parameters
def mode(self) -> torch.Tensor:
return self.parameters
class EncoderTiny(nn.Module):
r"""
The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
@@ -130,9 +130,8 @@ class MultiControlNetModel(ModelMixin):
A path to a *directory* containing model weights saved using
[`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
`./my_model_directory/controlnet`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
@@ -143,9 +143,8 @@ class MultiControlNetUnionModel(ModelMixin):
A path to a *directory* containing model weights saved using
[`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
`./my_model_directory/controlnet`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+3 -3
View File
@@ -31,7 +31,7 @@ def get_timestep_embedding(
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
) -> torch.Tensor:
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
@@ -1204,7 +1204,7 @@ def apply_rotary_emb(
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Used for Stable Audio, OmniGen and CogView4
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
@@ -1327,7 +1327,7 @@ class Timesteps(nn.Module):
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
+2 -3
View File
@@ -787,9 +787,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
+40
View File
@@ -171,6 +171,46 @@ class AdaLayerNormZero(nn.Module):
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroPruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__()
if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
else:
self.emb = None
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
class AdaLayerNormZeroSingle(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
@@ -17,12 +17,15 @@ if is_torch_available():
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_cosmos import CosmosTransformer3DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
@@ -0,0 +1,753 @@
# Copyright 2025 Black Forest Labs, The HuggingFace Team and lodestone-rock. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FluxAttnProcessor2_0_NPU,
FusedFluxAttnProcessor2_0,
)
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepLabelEmbeddings,
FluxPosEmbed,
PixArtAlphaTextProjection,
Timesteps,
get_timestep_embedding,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ChromaApproximator(nn.Module):
def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5):
super().__init__()
self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
self.layers = nn.ModuleList(
[PixArtAlphaTextProjection(hidden_dim, hidden_dim, act_fn="silu") for _ in range(n_layers)]
)
self.norms = nn.ModuleList([nn.RMSNorm(hidden_dim) for _ in range(n_layers)])
self.out_proj = nn.Linear(hidden_dim, out_dim)
def forward(self, x):
x = self.in_proj(x)
for layer, norms in zip(self.layers, self.norms):
x = x + layer(norms(x))
return self.out_proj(x)
class ChromaTimestepEmbeddings(nn.Module):
def __init__(
self,
num_channels: int,
out_dim: int,
):
super().__init__()
self.time_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
self.guidance_proj = Timesteps(num_channels=num_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
self.register_buffer(
"mod_proj",
get_timestep_embedding(
torch.arange(out_dim) * 1000,
2 * num_channels,
flip_sin_to_cos=True,
downscale_freq_shift=0,
),
persistent=False,
)
def forward(self, timestep: torch.Tensor) -> torch.Tensor:
mod_index_length = self.mod_proj.shape[0]
timesteps_proj = self.time_proj(timestep).to(dtype=timestep.dtype)
guidance_proj = self.guidance_proj(torch.tensor([0])).to(dtype=timestep.dtype, device=timestep.device)
mod_proj = self.mod_proj.to(dtype=timesteps_proj.dtype, device=timesteps_proj.device)
timestep_guidance = (
torch.cat([timesteps_proj, guidance_proj], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1)
)
input_vec = torch.cat([timestep_guidance, mod_proj.unsqueeze(0)], dim=-1)
return input_vec
class ChromaAdaLayerNormZeroSinglePruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
super().__init__()
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
shift_msa, scale_msa, gate_msa = emb.squeeze(0).chunk(3, dim=0)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class ChromaAdaLayerNormZeroPruned(nn.Module):
r"""
Norm layer adaptive layer norm zero (adaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, norm_type="layer_norm", bias=True):
super().__init__()
if num_embeddings is not None:
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
else:
self.emb = None
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
x: torch.Tensor,
timestep: Optional[torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.squeeze(0).chunk(6, dim=0)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
@maybe_allow_in_graph
class ChromaSingleTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = ChromaAdaLayerNormZeroSinglePruned(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
)
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
processor = FluxAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
@maybe_allow_in_graph
class ChromaTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
qk_norm: str = "rms_norm",
eps: float = 1e-6,
):
super().__init__()
self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=FluxAttnProcessor2_0(),
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
temb_img, temb_txt = temb[:, :6], temb[:, 6:]
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb_img)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb_txt
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class ChromaAdaLayerNormContinuous(nn.Module):
r"""
Adaptive normalization layer with a norm layer (layer_norm or rms_norm).
Args:
embedding_dim (`int`): Embedding dimension to use during projection.
conditioning_embedding_dim (`int`): Dimension of the input condition.
elementwise_affine (`bool`, defaults to `True`):
Boolean flag to denote if affine transformation should be applied.
eps (`float`, defaults to 1e-5): Epsilon factor.
bias (`bias`, defaults to `True`): Boolean flag to denote if bias should be use.
norm_type (`str`, defaults to `"layer_norm"`):
Normalization layer to use. Values supported: "layer_norm", "rms_norm".
"""
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
# However, this is how it was implemented in the original code, and it's rather likely you should
# set `elementwise_affine` to False.
elementwise_affine=True,
eps=1e-5,
bias=True,
norm_type="layer_norm",
):
super().__init__()
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = nn.RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
shift, scale = torch.chunk(emb.squeeze(0).to(x.dtype), 2, dim=0)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class ChromaTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
):
"""
The Transformer model based on Flux SCHNELL architecture.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Args:
patch_size (`int`, defaults to `1`):
Patch size to turn the input data into small patches.
in_channels (`int`, defaults to `64`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `None`):
The number of channels in the output. If not specified, it defaults to `in_channels`.
num_layers (`int`, defaults to `19`):
The number of layers of dual stream DiT blocks to use.
num_single_layers (`int`, defaults to `38`):
The number of layers of single stream DiT blocks to use.
attention_head_dim (`int`, defaults to `128`):
The number of dimensions to use for each attention head.
num_attention_heads (`int`, defaults to `24`):
The number of attention heads to use.
joint_attention_dim (`int`, defaults to `4096`):
The number of dimensions to use for the joint attention (embedding/channel dimension of
`encoder_hidden_states`).
pooled_projection_dim (`int`, defaults to `768`):
The number of dimensions to use for the pooled projection.
guidance_embeds (`bool`, defaults to `False`):
Whether to use guidance embeddings for guidance-distilled variant of the model.
axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions to use for the rotary positional embeddings.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
out_channels: Optional[int] = None,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
axes_dims_rope: Tuple[int, ...] = (16, 56, 56),
approximator_in_factor: int = 16,
approximator_hidden_dim: int = 5120,
approximator_layers: int = 5,
):
super().__init__()
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
self.time_text_embed = ChromaTimestepEmbeddings(
num_channels=approximator_in_factor, out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2
)
self.distilled_guidance_layer = ChromaApproximator(
in_dim=in_channels,
out_dim=self.inner_dim,
hidden_dim=approximator_hidden_dim,
n_layers=approximator_layers,
)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
ChromaTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
ChromaSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for _ in range(num_single_layers)
]
)
self.norm_out = ChromaAdaLayerNormContinuous(
self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedFluxAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
input_vec = self.time_text_embed(timestep)
pooled_temb = self.distilled_guidance_layer(input_vec)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
img_offset = 3 * len(self.single_transformer_blocks)
txt_offset = img_offset + 6 * len(self.transformer_blocks)
img_modulation = img_offset + 6 * index_block
text_modulation = txt_offset + 6 * index_block
temb = torch.cat(
(
pooled_temb[:, img_modulation : img_modulation + 6],
pooled_temb[:, text_modulation : text_modulation + 6],
),
dim=1,
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
start_idx = 3 * index_block
temb = pooled_temb[:, start_idx : start_idx + 3]
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
temb = pooled_temb[:, -2:]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@@ -0,0 +1,555 @@
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_processor import Attention
from ..embeddings import Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
if is_torchvision_available():
from torchvision import transforms
class CosmosPatchEmbed(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
) -> None:
super().__init__()
self.patch_size = patch_size
self.proj = nn.Linear(in_channels * patch_size[0] * patch_size[1] * patch_size[2], out_channels, bias=bias)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
hidden_states = hidden_states.reshape(
batch_size, num_channels, num_frames // p_t, p_t, height // p_h, p_h, width // p_w, p_w
)
hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7)
hidden_states = self.proj(hidden_states)
return hidden_states
class CosmosTimestepEmbedding(nn.Module):
def __init__(self, in_features: int, out_features: int) -> None:
super().__init__()
self.linear_1 = nn.Linear(in_features, out_features, bias=False)
self.activation = nn.SiLU()
self.linear_2 = nn.Linear(out_features, 3 * out_features, bias=False)
def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
emb = self.linear_1(timesteps)
emb = self.activation(emb)
emb = self.linear_2(emb)
return emb
class CosmosEmbedding(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int) -> None:
super().__init__()
self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
self.t_embedder = CosmosTimestepEmbedding(embedding_dim, condition_dim)
self.norm = RMSNorm(embedding_dim, eps=1e-6, elementwise_affine=True)
def forward(self, hidden_states: torch.Tensor, timestep: torch.LongTensor) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
temb = self.t_embedder(timesteps_proj)
embedded_timestep = self.norm(timesteps_proj)
return temb, embedded_timestep
class CosmosAdaLayerNorm(nn.Module):
def __init__(self, in_features: int, hidden_features: int) -> None:
super().__init__()
self.embedding_dim = in_features
self.activation = nn.SiLU()
self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
self.linear_2 = nn.Linear(hidden_features, 2 * in_features, bias=False)
def forward(
self, hidden_states: torch.Tensor, embedded_timestep: torch.Tensor, temb: Optional[torch.Tensor] = None
) -> torch.Tensor:
embedded_timestep = self.activation(embedded_timestep)
embedded_timestep = self.linear_1(embedded_timestep)
embedded_timestep = self.linear_2(embedded_timestep)
if temb is not None:
embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
shift, scale = embedded_timestep.chunk(2, dim=1)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return hidden_states
class CosmosAdaLayerNormZero(nn.Module):
def __init__(self, in_features: int, hidden_features: Optional[int] = None) -> None:
super().__init__()
self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6)
self.activation = nn.SiLU()
if hidden_features is None:
self.linear_1 = nn.Identity()
else:
self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)
self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False)
def forward(
self,
hidden_states: torch.Tensor,
embedded_timestep: torch.Tensor,
temb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
embedded_timestep = self.activation(embedded_timestep)
embedded_timestep = self.linear_1(embedded_timestep)
embedded_timestep = self.linear_2(embedded_timestep)
if temb is not None:
embedded_timestep = embedded_timestep + temb
shift, scale, gate = embedded_timestep.chunk(3, dim=1)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
return hidden_states, gate
class CosmosAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# 1. QKV projections
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
query = attn.norm_q(query)
key = attn.norm_k(key)
# 3. Apply RoPE
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA
query_idx = torch.tensor(query.size(3), device=query.device)
key_idx = torch.tensor(key.size(3), device=key.device)
value_idx = torch.tensor(value.size(3), device=value.device)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
# 5. Attention
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
# 6. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CosmosTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: int,
mlp_ratio: float = 4.0,
adaln_lora_dim: int = 256,
qk_norm: str = "rms_norm",
out_bias: bool = False,
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.attn1 = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_0(),
)
self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.attn2 = Attention(
query_dim=hidden_size,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
qk_norm=qk_norm,
elementwise_affine=True,
out_bias=out_bias,
processor=CosmosAttnProcessor2_0(),
)
self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size, hidden_features=adaln_lora_dim)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu", bias=out_bias)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
embedded_timestep: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
extra_pos_emb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if extra_pos_emb is not None:
hidden_states = hidden_states + extra_pos_emb
# 1. Self Attention
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
# 2. Cross Attention
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
attn_output = self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
# 3. Feed Forward
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
return hidden_states
class CosmosRotaryPosEmbed(nn.Module):
def __init__(
self,
hidden_size: int,
max_size: Tuple[int, int, int] = (128, 240, 240),
patch_size: Tuple[int, int, int] = (1, 2, 2),
base_fps: int = 24,
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
) -> None:
super().__init__()
self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
self.patch_size = patch_size
self.base_fps = base_fps
self.dim_h = hidden_size // 6 * 2
self.dim_w = hidden_size // 6 * 2
self.dim_t = hidden_size - self.dim_h - self.dim_w
self.h_ntk_factor = rope_scale[1] ** (self.dim_h / (self.dim_h - 2))
self.w_ntk_factor = rope_scale[2] ** (self.dim_w / (self.dim_w - 2))
self.t_ntk_factor = rope_scale[0] ** (self.dim_t / (self.dim_t - 2))
def forward(self, hidden_states: torch.Tensor, fps: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
device = hidden_states.device
h_theta = 10000.0 * self.h_ntk_factor
w_theta = 10000.0 * self.w_ntk_factor
t_theta = 10000.0 * self.t_ntk_factor
seq = torch.arange(max(self.max_size), device=device, dtype=torch.float32)
dim_h_range = (
torch.arange(0, self.dim_h, 2, device=device, dtype=torch.float32)[: (self.dim_h // 2)] / self.dim_h
)
dim_w_range = (
torch.arange(0, self.dim_w, 2, device=device, dtype=torch.float32)[: (self.dim_w // 2)] / self.dim_w
)
dim_t_range = (
torch.arange(0, self.dim_t, 2, device=device, dtype=torch.float32)[: (self.dim_t // 2)] / self.dim_t
)
h_spatial_freqs = 1.0 / (h_theta**dim_h_range)
w_spatial_freqs = 1.0 / (w_theta**dim_w_range)
temporal_freqs = 1.0 / (t_theta**dim_t_range)
emb_h = torch.outer(seq[: pe_size[1]], h_spatial_freqs)[None, :, None, :].repeat(pe_size[0], 1, pe_size[2], 1)
emb_w = torch.outer(seq[: pe_size[2]], w_spatial_freqs)[None, None, :, :].repeat(pe_size[0], pe_size[1], 1, 1)
# Apply sequence scaling in temporal dimension
if fps is None:
# Images
emb_t = torch.outer(seq[: pe_size[0]], temporal_freqs)
else:
# Videos
emb_t = torch.outer(seq[: pe_size[0]] / fps * self.base_fps, temporal_freqs)
emb_t = emb_t[:, None, None, :].repeat(1, pe_size[1], pe_size[2], 1)
freqs = torch.cat([emb_t, emb_h, emb_w] * 2, dim=-1).flatten(0, 2).float()
cos = torch.cos(freqs)
sin = torch.sin(freqs)
return cos, sin
class CosmosLearnablePositionalEmbed(nn.Module):
def __init__(
self,
hidden_size: int,
max_size: Tuple[int, int, int],
patch_size: Tuple[int, int, int],
eps: float = 1e-6,
) -> None:
super().__init__()
self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
self.patch_size = patch_size
self.eps = eps
self.pos_emb_t = nn.Parameter(torch.zeros(self.max_size[0], hidden_size))
self.pos_emb_h = nn.Parameter(torch.zeros(self.max_size[1], hidden_size))
self.pos_emb_w = nn.Parameter(torch.zeros(self.max_size[2], hidden_size))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
emb = emb_t + emb_h + emb_w
emb = emb.flatten(1, 3)
norm = torch.linalg.vector_norm(emb, dim=-1, keepdim=True, dtype=torch.float32)
norm = torch.add(self.eps, norm, alpha=np.sqrt(norm.numel() / emb.numel()))
return (emb / norm).type_as(hidden_states)
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
r"""
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
Args:
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
num_attention_heads (`int`, defaults to `32`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each attention head.
num_layers (`int`, defaults to `28`):
The number of layers of transformer blocks to use.
mlp_ratio (`float`, defaults to `4.0`):
The ratio of the hidden layer size to the input size in the feedforward network.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
adaln_lora_dim (`int`, defaults to `256`):
The hidden dimension of the Adaptive LayerNorm LoRA layer.
max_size (`Tuple[int, int, int]`, defaults to `(128, 240, 240)`):
The maximum size of the input latent tensors in the temporal, height, and width dimensions.
patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
The patch size to use for patchifying the input latent tensors in the temporal, height, and width
dimensions.
rope_scale (`Tuple[float, float, float]`, defaults to `(2.0, 1.0, 1.0)`):
The scaling factor to use for RoPE in the temporal, height, and width dimensions.
concat_padding_mask (`bool`, defaults to `True`):
Whether to concatenate the padding mask to the input latent tensors.
extra_pos_embed_type (`str`, *optional*, defaults to `learnable`):
The type of extra positional embeddings to use. Can be one of `None` or `learnable`.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["patch_embed", "final_layer", "norm"]
_no_split_modules = ["CosmosTransformerBlock"]
_keep_in_fp32_modules = ["learnable_pos_embed"]
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
num_attention_heads: int = 32,
attention_head_dim: int = 128,
num_layers: int = 28,
mlp_ratio: float = 4.0,
text_embed_dim: int = 1024,
adaln_lora_dim: int = 256,
max_size: Tuple[int, int, int] = (128, 240, 240),
patch_size: Tuple[int, int, int] = (1, 2, 2),
rope_scale: Tuple[float, float, float] = (2.0, 1.0, 1.0),
concat_padding_mask: bool = True,
extra_pos_embed_type: Optional[str] = "learnable",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
# 1. Patch Embedding
patch_embed_in_channels = in_channels + 1 if concat_padding_mask else in_channels
self.patch_embed = CosmosPatchEmbed(patch_embed_in_channels, hidden_size, patch_size, bias=False)
# 2. Positional Embedding
self.rope = CosmosRotaryPosEmbed(
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
)
self.learnable_pos_embed = None
if extra_pos_embed_type == "learnable":
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
hidden_size=hidden_size,
max_size=max_size,
patch_size=patch_size,
)
# 3. Time Embedding
self.time_embed = CosmosEmbedding(hidden_size, hidden_size)
# 4. Transformer Blocks
self.transformer_blocks = nn.ModuleList(
[
CosmosTransformerBlock(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
cross_attention_dim=text_embed_dim,
mlp_ratio=mlp_ratio,
adaln_lora_dim=adaln_lora_dim,
qk_norm="rms_norm",
out_bias=False,
)
for _ in range(num_layers)
]
)
# 5. Output norm & projection
self.norm_out = CosmosAdaLayerNorm(hidden_size, adaln_lora_dim)
self.proj_out = nn.Linear(
hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels, bias=False
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
fps: Optional[int] = None,
condition_mask: Optional[torch.Tensor] = None,
padding_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
# 1. Concatenate padding mask if needed & prepare attention mask
if condition_mask is not None:
hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
if self.config.concat_padding_mask:
padding_mask = transforms.functional.resize(
padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)
hidden_states = torch.cat(
[hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
)
if attention_mask is not None:
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, S]
# 2. Generate positional embeddings
image_rotary_emb = self.rope(hidden_states, fps=fps)
extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
# 3. Patchify input
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p_h
post_patch_width = width // p_w
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
# 4. Timestep embeddings
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
# 5. Transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
embedded_timestep,
temb,
image_rotary_emb,
extra_pos_emb,
attention_mask,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
embedded_timestep=embedded_timestep,
temb=temb,
image_rotary_emb=image_rotary_emb,
extra_pos_emb=extra_pos_emb,
attention_mask=attention_mask,
)
# 6. Output norm & projection & unpatchify
hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
hidden_states = hidden_states.unflatten(1, (post_patch_num_frames, post_patch_height, post_patch_width))
# Please just kill me at this point. What even is this permutation order and why is it different from the patching order?
# Another few hours of sanity lost to the void.
hidden_states = hidden_states.permute(0, 7, 1, 6, 2, 4, 3, 5)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
@@ -241,7 +241,7 @@ class FluxTransformer2DModel(
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int] = (16, 56, 56),
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
):
super().__init__()
self.out_channels = out_channels or in_channels
@@ -447,8 +447,6 @@ class FluxTransformer2DModel(
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
@@ -602,7 +602,7 @@ class HiDreamBlock(nn.Module):
)
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
_no_split_modules = ["HiDreamImageTransformerBlock", "HiDreamImageSingleTransformerBlock"]
@@ -0,0 +1,416 @@
# Copyright 2025 The Framepack Team, The Hunyuan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
from ..cache_utils import CacheMixin
from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous
from .transformer_hunyuan_video import (
HunyuanVideoConditionEmbedding,
HunyuanVideoPatchEmbed,
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenRefiner,
HunyuanVideoTransformerBlock,
)
logger = get_logger(__name__) # pylint: disable=invalid-name
class HunyuanVideoFramepackRotaryPosEmbed(nn.Module):
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.rope_dim = rope_dim
self.theta = theta
def forward(self, frame_indices: torch.Tensor, height: int, width: int, device: torch.device):
height = height // self.patch_size
width = width // self.patch_size
grid = torch.meshgrid(
frame_indices.to(device=device, dtype=torch.float32),
torch.arange(0, height, device=device, dtype=torch.float32),
torch.arange(0, width, device=device, dtype=torch.float32),
indexing="ij",
) # 3 * [W, H, T]
grid = torch.stack(grid, dim=0) # [3, W, H, T]
freqs = []
for i in range(3):
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
freqs.append(freq)
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
return freqs_cos, freqs_sin
class FramepackClipVisionProjection(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.up = nn.Linear(in_channels, out_channels * 3)
self.down = nn.Linear(out_channels * 3, out_channels)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.up(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.down(hidden_states)
return hidden_states
class HunyuanVideoHistoryPatchEmbed(nn.Module):
def __init__(self, in_channels: int, inner_dim: int):
super().__init__()
self.proj = nn.Conv3d(in_channels, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
self.proj_2x = nn.Conv3d(in_channels, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
self.proj_4x = nn.Conv3d(in_channels, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
def forward(
self,
latents_clean: Optional[torch.Tensor] = None,
latents_clean_2x: Optional[torch.Tensor] = None,
latents_clean_4x: Optional[torch.Tensor] = None,
):
if latents_clean is not None:
latents_clean = self.proj(latents_clean)
latents_clean = latents_clean.flatten(2).transpose(1, 2)
if latents_clean_2x is not None:
latents_clean_2x = _pad_for_3d_conv(latents_clean_2x, (2, 4, 4))
latents_clean_2x = self.proj_2x(latents_clean_2x)
latents_clean_2x = latents_clean_2x.flatten(2).transpose(1, 2)
if latents_clean_4x is not None:
latents_clean_4x = _pad_for_3d_conv(latents_clean_4x, (4, 8, 8))
latents_clean_4x = self.proj_4x(latents_clean_4x)
latents_clean_4x = latents_clean_4x.flatten(2).transpose(1, 2)
return latents_clean, latents_clean_2x, latents_clean_4x
class HunyuanVideoFramepackTransformer3DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin
):
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
_no_split_modules = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoHistoryPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(
self,
in_channels: int = 16,
out_channels: int = 16,
num_attention_heads: int = 24,
attention_head_dim: int = 128,
num_layers: int = 20,
num_single_layers: int = 40,
num_refiner_layers: int = 2,
mlp_ratio: float = 4.0,
patch_size: int = 2,
patch_size_t: int = 1,
qk_norm: str = "rms_norm",
guidance_embeds: bool = True,
text_embed_dim: int = 4096,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
image_condition_type: Optional[str] = None,
has_image_proj: int = False,
image_proj_dim: int = 1152,
has_clean_x_embedder: int = False,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Latent and condition embedders
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
# Framepack history projection embedder
self.clean_x_embedder = None
if has_clean_x_embedder:
self.clean_x_embedder = HunyuanVideoHistoryPatchEmbed(in_channels, inner_dim)
self.context_embedder = HunyuanVideoTokenRefiner(
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
# Framepack image-conditioning embedder
self.image_projection = FramepackClipVisionProjection(image_proj_dim, inner_dim) if has_image_proj else None
self.time_text_embed = HunyuanVideoConditionEmbedding(
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
)
# 2. RoPE
self.rope = HunyuanVideoFramepackRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
image_embeds: torch.Tensor,
indices_latents: torch.Tensor,
guidance: Optional[torch.Tensor] = None,
latents_clean: Optional[torch.Tensor] = None,
indices_latents_clean: Optional[torch.Tensor] = None,
latents_history_2x: Optional[torch.Tensor] = None,
indices_latents_history_2x: Optional[torch.Tensor] = None,
latents_history_4x: Optional[torch.Tensor] = None,
indices_latents_history_4x: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
original_context_length = post_patch_num_frames * post_patch_height * post_patch_width
if indices_latents is None:
indices_latents = torch.arange(0, num_frames).unsqueeze(0).expand(batch_size, -1)
hidden_states = self.x_embedder(hidden_states)
image_rotary_emb = self.rope(
frame_indices=indices_latents, height=height, width=width, device=hidden_states.device
)
latents_clean, latents_history_2x, latents_history_4x = self.clean_x_embedder(
latents_clean, latents_history_2x, latents_history_4x
)
if latents_clean is not None and indices_latents_clean is not None:
image_rotary_emb_clean = self.rope(
frame_indices=indices_latents_clean, height=height, width=width, device=hidden_states.device
)
if latents_history_2x is not None and indices_latents_history_2x is not None:
image_rotary_emb_history_2x = self.rope(
frame_indices=indices_latents_history_2x, height=height, width=width, device=hidden_states.device
)
if latents_history_4x is not None and indices_latents_history_4x is not None:
image_rotary_emb_history_4x = self.rope(
frame_indices=indices_latents_history_4x, height=height, width=width, device=hidden_states.device
)
hidden_states, image_rotary_emb = self._pack_history_states(
hidden_states,
latents_clean,
latents_history_2x,
latents_history_4x,
image_rotary_emb,
image_rotary_emb_clean,
image_rotary_emb_history_2x,
image_rotary_emb_history_4x,
post_patch_height,
post_patch_width,
)
temb, _ = self.time_text_embed(timestep, pooled_projections, guidance)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
encoder_hidden_states_image = self.image_projection(image_embeds)
attention_mask_image = encoder_attention_mask.new_ones((batch_size, encoder_hidden_states_image.shape[1]))
# must cat before (not after) encoder_hidden_states, due to attn masking
encoder_hidden_states = torch.cat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
encoder_attention_mask = torch.cat([attention_mask_image, encoder_attention_mask], dim=1)
latent_sequence_length = hidden_states.shape[1]
condition_sequence_length = encoder_hidden_states.shape[1]
sequence_length = latent_sequence_length + condition_sequence_length
attention_mask = torch.zeros(
batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
) # [B, N]
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
if batch_size == 1:
encoder_hidden_states = encoder_hidden_states[:, : effective_condition_sequence_length[0]]
attention_mask = None
else:
for i in range(batch_size):
attention_mask[i, : effective_sequence_length[i]] = True
# [B, 1, 1, N], for broadcasting across attention heads
attention_mask = attention_mask.unsqueeze(1).unsqueeze(1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block, hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
)
hidden_states = hidden_states[:, -original_context_length:]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
)
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)
def _pack_history_states(
self,
hidden_states: torch.Tensor,
latents_clean: Optional[torch.Tensor] = None,
latents_history_2x: Optional[torch.Tensor] = None,
latents_history_4x: Optional[torch.Tensor] = None,
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] = None,
image_rotary_emb_clean: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
image_rotary_emb_history_2x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
image_rotary_emb_history_4x: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
height: int = None,
width: int = None,
):
image_rotary_emb = list(image_rotary_emb) # convert tuple to list for in-place modification
if latents_clean is not None and image_rotary_emb_clean is not None:
hidden_states = torch.cat([latents_clean, hidden_states], dim=1)
image_rotary_emb[0] = torch.cat([image_rotary_emb_clean[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_clean[1], image_rotary_emb[1]], dim=0)
if latents_history_2x is not None and image_rotary_emb_history_2x is not None:
hidden_states = torch.cat([latents_history_2x, hidden_states], dim=1)
image_rotary_emb_history_2x = self._pad_rotary_emb(image_rotary_emb_history_2x, height, width, (2, 2, 2))
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_2x[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_2x[1], image_rotary_emb[1]], dim=0)
if latents_history_4x is not None and image_rotary_emb_history_4x is not None:
hidden_states = torch.cat([latents_history_4x, hidden_states], dim=1)
image_rotary_emb_history_4x = self._pad_rotary_emb(image_rotary_emb_history_4x, height, width, (4, 4, 4))
image_rotary_emb[0] = torch.cat([image_rotary_emb_history_4x[0], image_rotary_emb[0]], dim=0)
image_rotary_emb[1] = torch.cat([image_rotary_emb_history_4x[1], image_rotary_emb[1]], dim=0)
return hidden_states, tuple(image_rotary_emb)
def _pad_rotary_emb(
self,
image_rotary_emb: Tuple[torch.Tensor],
height: int,
width: int,
kernel_size: Tuple[int, int, int],
):
# freqs_cos, freqs_sin have shape [W * H * T, D / 2], where D is attention head dim
freqs_cos, freqs_sin = image_rotary_emb
freqs_cos = freqs_cos.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
freqs_sin = freqs_sin.unsqueeze(0).permute(0, 2, 1).unflatten(2, (-1, height, width))
freqs_cos = _pad_for_3d_conv(freqs_cos, kernel_size)
freqs_sin = _pad_for_3d_conv(freqs_sin, kernel_size)
freqs_cos = _center_down_sample_3d(freqs_cos, kernel_size)
freqs_sin = _center_down_sample_3d(freqs_sin, kernel_size)
freqs_cos = freqs_cos.flatten(2).permute(0, 2, 1).squeeze(0)
freqs_sin = freqs_sin.flatten(2).permute(0, 2, 1).squeeze(0)
return freqs_cos, freqs_sin
def _pad_for_3d_conv(x, kernel_size):
if isinstance(x, (tuple, list)):
return tuple(_pad_for_3d_conv(i, kernel_size) for i in x)
b, c, t, h, w = x.shape
pt, ph, pw = kernel_size
pad_t = (pt - (t % pt)) % pt
pad_h = (ph - (h % ph)) % ph
pad_w = (pw - (w % pw)) % pw
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
def _center_down_sample_3d(x, kernel_size):
if isinstance(x, (tuple, list)):
return tuple(_center_down_sample_3d(i, kernel_size) for i in x)
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
+14 -2
View File
@@ -156,6 +156,8 @@ else:
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["cosmos"] = ["CosmosTextToWorldPipeline", "CosmosVideoToWorldPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -227,6 +229,7 @@ else:
"HunyuanVideoPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
@@ -265,7 +268,12 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
_import_structure["ltx"] = [
"LTXPipeline",
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
"LTXLatentUpsamplePipeline",
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
@@ -278,6 +286,7 @@ else:
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -545,6 +554,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionControlNetXSPipeline,
StableDiffusionXLControlNetXSPipeline,
)
from .cosmos import CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
from .deepfloyd_if import (
IFImg2ImgPipeline,
IFImg2ImgSuperResolutionPipeline,
@@ -589,6 +599,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .hidream_image import HiDreamImagePipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
HunyuanVideoPipeline,
)
@@ -631,7 +642,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
@@ -722,6 +733,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
@@ -40,6 +40,7 @@ from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.import_utils import is_transformers_version
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
@@ -312,8 +313,19 @@ class AudioLDM2Pipeline(DiffusionPipeline):
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.0.dev0"):
cache_position_kwargs["input_ids"] = inputs_embeds
cache_position_kwargs["model_kwargs"] = model_kwargs
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(inputs_embeds, model_kwargs)
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
for _ in range(max_new_tokens):
# prepare model inputs
model_inputs = prepare_inputs_for_generation(inputs_embeds, **model_kwargs)
+4 -7
View File
@@ -322,9 +322,8 @@ class AutoPipelineForText2Image(ConfigMixin):
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
saved using
[`~DiffusionPipeline.save_pretrained`].
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
dtype is automatically derived from the model's weights.
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -619,8 +618,7 @@ class AutoPipelineForImage2Image(ConfigMixin):
saved using
[`~DiffusionPipeline.save_pretrained`].
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
dtype is automatically derived from the model's weights.
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -930,8 +928,7 @@ class AutoPipelineForInpainting(ConfigMixin):
saved using
[`~DiffusionPipeline.save_pretrained`].
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
dtype is automatically derived from the model's weights.
Override the default `torch.dtype` and load the model with another dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
_import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,667 @@
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
from ...schedulers import EDMEulerScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:
class CosmosSafetyChecker:
def __init__(self, *args, **kwargs):
raise ImportError(
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CosmosTextToWorldPipeline
>>> from diffusers.utils import export_to_video
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
>>> pipe = CosmosTextToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
>>> output = pipe(prompt=prompt).frames[0]
>>> export_to_video(output, "output.mp4", fps=30)
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class CosmosTextToWorldPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Cosmos uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CosmosTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLCosmos`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker"]
def __init__(
self,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLCosmos,
scheduler: EDMEulerScheduler,
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
if safety_checker is None:
safety_checker = CosmosSafetyChecker()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
return_length=True,
return_offsets_mapping=False,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=prompt_attention_mask
).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
lengths = prompt_attention_mask.sum(dim=1).cpu()
for i, length in enumerate(lengths):
prompt_embeds[i, length:] = 0
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self,
batch_size: int,
num_channels_latents: 16,
height: int = 704,
width: int = 1280,
num_frames: int = 121,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents * self.scheduler.config.sigma_max
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 704,
width: int = 1280,
num_frames: int = 121,
num_inference_steps: int = 36,
guidance_scale: float = 7.0,
fps: int = 30,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `720`):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `129`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `6.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`.
fps (`int`, defaults to `30`):
The frames per second of the generated video.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~CosmosPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if self.safety_checker is None:
raise ValueError(
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
f"Please ensure that you are compliant with the license agreement."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
device = self._execution_device
if self.safety_checker is not None:
self.safety_checker.to(device)
if prompt is not None:
prompt_list = [prompt] if isinstance(prompt, str) else prompt
for p in prompt_list:
if not self.safety_checker.check_text_safety(p):
raise ValueError(
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
f"prompt abides by the NVIDIA Open Model License Agreement."
)
self.safety_checker.to("cpu")
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
# 5. Prepare latent variables
transformer_dtype = self.transformer.dtype
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
latent_model_input = latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = latent_model_input.to(transformer_dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
fps=fps,
padding_mask=padding_mask,
return_dict=False,
)[0]
sample = latents
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
fps=fps,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
sample = torch.cat([sample, sample])
# pred_original_sample (x0)
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
self.scheduler._step_index -= 1
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
# pred_sample (eps)
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
if self.vae.config.latents_mean is not None:
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
latents_mean = (
torch.tensor(latents_mean)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
.to(latents)
)
latents_std = (
torch.tensor(latents_std)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
.to(latents)
)
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
else:
latents = latents / self.scheduler.config.sigma_data
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
if self.safety_checker is not None:
self.safety_checker.to(device)
video = self.video_processor.postprocess_video(video, output_type="np")
video = (video * 255).astype(np.uint8)
video_batch = []
for vid in video:
vid = self.safety_checker.check_video_safety(vid)
video_batch.append(vid)
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
self.safety_checker.to("cpu")
else:
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CosmosPipelineOutput(frames=video)
@@ -0,0 +1,828 @@
# Copyright 2024 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
from ...schedulers import EDMEulerScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:
class CosmosSafetyChecker:
def __init__(self, *args, **kwargs):
raise ImportError(
"`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: `pip install cosmos_guardrail`."
)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
Image conditioning:
```python
>>> import torch
>>> from diffusers import CosmosVideoToWorldPipeline
>>> from diffusers.utils import export_to_video, load_image
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day."
>>> image = load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
... )
>>> video = pipe(image=image, prompt=prompt).frames[0]
>>> export_to_video(video, "output.mp4", fps=30)
```
Video conditioning:
```python
>>> import torch
>>> from diffusers import CosmosVideoToWorldPipeline
>>> from diffusers.utils import export_to_video, load_video
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
>>> pipe.transformer = torch.compile(pipe.transformer)
>>> pipe.to("cuda")
>>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
>>> video = load_video(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
... )[
... :21
... ] # This example uses only the first 21 frames
>>> video = pipe(video=video, prompt=prompt).frames[0]
>>> export_to_video(video, "output.mp4", fps=30)
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class CosmosVideoToWorldPipeline(DiffusionPipeline):
r"""
Pipeline for image-to-video and video-to-video generation using [Cosmos](https://github.com/NVIDIA/Cosmos).
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. Cosmos uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
tokenizer (`T5TokenizerFast`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CosmosTransformer3DModel`]):
Conditional Transformer to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLCosmos`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker"]
def __init__(
self,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLCosmos,
scheduler: EDMEulerScheduler,
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
if safety_checker is None:
safety_checker = CosmosSafetyChecker()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
)
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
return_length=True,
return_offsets_mapping=False,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=prompt_attention_mask
).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
lengths = prompt_attention_mask.sum(dim=1).cpu()
for i, length in enumerate(lengths):
prompt_embeds[i, length:] = 0
return prompt_embeds
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self,
video: torch.Tensor,
batch_size: int,
num_channels_latents: 16,
height: int = 704,
width: int = 1280,
num_frames: int = 121,
do_classifier_free_guidance: bool = True,
input_frames_guidance: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
num_cond_frames = video.size(2)
if num_cond_frames >= num_frames:
# Take the last `num_frames` frames for conditioning
num_cond_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
video = video[:, :, -num_frames:]
else:
num_cond_latent_frames = (num_cond_frames - 1) // self.vae_scale_factor_temporal + 1
num_padding_frames = num_frames - num_cond_frames
padding = video.new_zeros(video.size(0), video.size(1), num_padding_frames, video.size(3), video.size(4))
video = torch.cat([video, padding], dim=2)
if isinstance(generator, list):
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i])
for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
if self.vae.config.latents_mean is not None:
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
latents_mean = (
torch.tensor(latents_mean)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
.to(init_latents)
)
latents_std = (
torch.tensor(latents_std)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : init_latents.size(2)]
.to(init_latents)
)
init_latents = (init_latents - latents_mean) * self.scheduler.config.sigma_data / latents_std
else:
init_latents = init_latents * self.scheduler.config.sigma_data
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
latent_height = height // self.vae_scale_factor_spatial
latent_width = width // self.vae_scale_factor_spatial
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
latents = latents * self.scheduler.config.sigma_max
padding_shape = (batch_size, 1, num_latent_frames, latent_height, latent_width)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_indicator[:, :, :num_cond_latent_frames] = 1.0
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
uncond_indicator = uncond_mask = None
if do_classifier_free_guidance:
uncond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
uncond_indicator[:, :, :num_cond_latent_frames] = 1.0
uncond_mask = zeros_padding
if not input_frames_guidance:
uncond_mask = uncond_indicator * ones_padding + (1 - uncond_indicator) * zeros_padding
return latents, init_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask
def check_inputs(
self,
prompt,
height,
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
image=None,
video=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if image is None and video is None:
raise ValueError("Either `image` or `video` has to be provided.")
if image is not None and video is not None:
raise ValueError("Only one of `image` or `video` has to be provided.")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
image: PipelineImageInput = None,
video: List[PipelineImageInput] = None,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 704,
width: int = 1280,
num_frames: int = 121,
num_inference_steps: int = 36,
guidance_scale: float = 7.0,
input_frames_guidance: bool = False,
augment_sigma: float = 0.001,
fps: int = 30,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `720`):
The height in pixels of the generated image.
width (`int`, defaults to `1280`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `129`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `6.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`.
fps (`int`, defaults to `30`):
The frames per second of the generated video.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~CosmosPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if self.safety_checker is None:
raise ValueError(
f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
"[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
f"Please ensure that you are compliant with the license agreement."
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs, image, video)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
device = self._execution_device
if self.safety_checker is not None:
self.safety_checker.to(device)
if prompt is not None:
prompt_list = [prompt] if isinstance(prompt, str) else prompt
for p in prompt_list:
if not self.safety_checker.check_text_safety(p):
raise ValueError(
f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
f"prompt abides by the NVIDIA Open Model License Agreement."
)
self.safety_checker.to("cpu")
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)
# 5. Prepare latent variables
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
if image is not None:
video = self.video_processor.preprocess(image, height, width).unsqueeze(2)
else:
video = self.video_processor.preprocess_video(video, height, width)
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
latents, conditioning_latents, cond_indicator, uncond_indicator, cond_mask, uncond_mask = self.prepare_latents(
video,
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
self.do_classifier_free_guidance,
input_frames_guidance,
torch.float32,
device,
generator,
latents,
)
cond_mask = cond_mask.to(transformer_dtype)
if self.do_classifier_free_guidance:
uncond_mask = uncond_mask.to(transformer_dtype)
augment_sigma = torch.tensor([augment_sigma], device=device, dtype=torch.float32)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
timestep = t.expand(latents.shape[0]).to(transformer_dtype)
current_sigma = self.scheduler.sigmas[i]
is_augment_sigma_greater = augment_sigma >= current_sigma
c_in_augment = self.scheduler._get_conditioning_c_in(augment_sigma)
c_in_original = self.scheduler._get_conditioning_c_in(current_sigma)
current_cond_indicator = cond_indicator * 0 if is_augment_sigma_greater else cond_indicator
cond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
cond_latent = conditioning_latents + cond_noise * augment_sigma[:, None, None, None, None]
cond_latent = cond_latent * c_in_augment / c_in_original
cond_latent = current_cond_indicator * cond_latent + (1 - current_cond_indicator) * latents
cond_latent = self.scheduler.scale_model_input(cond_latent, t)
cond_latent = cond_latent.to(transformer_dtype)
noise_pred = self.transformer(
hidden_states=cond_latent,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
fps=fps,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
sample = latents
if self.do_classifier_free_guidance:
current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
uncond_latent = conditioning_latents + uncond_noise * augment_sigma[:, None, None, None, None]
uncond_latent = uncond_latent * c_in_augment / c_in_original
uncond_latent = current_uncond_indicator * uncond_latent + (1 - current_uncond_indicator) * latents
uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
uncond_latent = uncond_latent.to(transformer_dtype)
noise_pred_uncond = self.transformer(
hidden_states=uncond_latent,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
fps=fps,
condition_mask=uncond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
sample = torch.cat([sample, sample])
# pred_original_sample (x0)
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
self.scheduler._step_index -= 1
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0)
noise_pred_uncond = (
current_uncond_indicator * conditioning_latents
+ (1 - current_uncond_indicator) * noise_pred_uncond
)
noise_pred_cond = (
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred_cond
)
noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = (
current_cond_indicator * conditioning_latents + (1 - current_cond_indicator) * noise_pred
)
# pred_sample (eps)
latents = self.scheduler.step(
noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
if self.vae.config.latents_mean is not None:
latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
latents_mean = (
torch.tensor(latents_mean)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
.to(latents)
)
latents_std = (
torch.tensor(latents_std)
.view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
.to(latents)
)
latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
else:
latents = latents / self.scheduler.config.sigma_data
video = self.vae.decode(latents.to(vae_dtype), return_dict=False)[0]
if self.safety_checker is not None:
self.safety_checker.to(device)
video = self.video_processor.postprocess_video(video, output_type="np")
video = (video * 255).astype(np.uint8)
video_batch = []
for vid in video:
vid = self.safety_checker.check_video_safety(vid)
video_batch.append(vid)
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
self.safety_checker.to("cpu")
else:
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CosmosPipelineOutput(frames=video)
@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class CosmosPipelineOutput(BaseOutput):
r"""
Output class for Cosmos pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
@@ -687,11 +687,11 @@ class FluxPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -700,7 +700,7 @@ class FluxPipeline(
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -607,6 +607,39 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def prepare_latents(
self,
image,
@@ -36,11 +36,11 @@ EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> from diffusers import HiDreamImagePipeline
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> tokenizer_4 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
... "meta-llama/Meta-Llama-3.1-8B-Instruct",
... output_hidden_states=True,
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_hunyuan_skyreels_image2video"] = ["HunyuanSkyreelsImageToVideoPipeline"]
_import_structure["pipeline_hunyuan_video"] = ["HunyuanVideoPipeline"]
_import_structure["pipeline_hunyuan_video_framepack"] = ["HunyuanVideoFramepackPipeline"]
_import_structure["pipeline_hunyuan_video_image2video"] = ["HunyuanVideoImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -36,6 +37,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_hunyuan_skyreels_image2video import HunyuanSkyreelsImageToVideoPipeline
from .pipeline_hunyuan_video import HunyuanVideoPipeline
from .pipeline_hunyuan_video_framepack import HunyuanVideoFramepackPipeline
from .pipeline_hunyuan_video_image2video import HunyuanVideoImageToVideoPipeline
else:
File diff suppressed because it is too large Load Diff
@@ -1,5 +1,8 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
import torch
from diffusers.utils import BaseOutput
@@ -18,3 +21,19 @@ class HunyuanVideoPipelineOutput(BaseOutput):
"""
frames: torch.Tensor
@dataclass
class HunyuanVideoFramepackPipelineOutput(BaseOutput):
r"""
Output class for HunyuanVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`. Or, a list of torch tensors where each tensor
corresponds to a latent that decodes to multiple frames.
"""
frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]], List[torch.Tensor]]
+4
View File
@@ -22,9 +22,11 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -34,9 +36,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
else:
import sys
@@ -0,0 +1,188 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
class ResBlock(torch.nn.Module):
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.activation(hidden_states + residual)
return hidden_states
class PixelShuffleND(torch.nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
self.dims = dims
self.upscale_factors = upscale_factors
if dims not in [1, 2, 3]:
raise ValueError("dims must be 1, 2, or 3")
def forward(self, x):
if self.dims == 3:
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
return (
x.unflatten(1, (-1, *self.upscale_factors[:3]))
.permute(0, 1, 5, 2, 6, 3, 7, 4)
.flatten(6, 7)
.flatten(4, 5)
.flatten(2, 3)
)
elif self.dims == 2:
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
return (
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
)
elif self.dims == 1:
# temporal: b (c p1) f h w -> b c (f p1) h w
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
class LTXLatentUpsamplerModel(ModelMixin, ConfigMixin):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`, defaults to `128`):
Number of channels in the input latent
mid_channels (`int`, defaults to `512`):
Number of channels in the middle layers
num_blocks_per_stage (`int`, defaults to `4`):
Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`, defaults to `3`):
Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`, defaults to `True`):
Whether to spatially upsample the latent
temporal_upsample (`bool`, defaults to `False`):
Whether to temporally upsample the latent
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 512,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.dims == 2:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.upsampler(hidden_states)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
else:
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
if self.temporal_upsample:
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states[:, :, 1:, :, :]
else:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
return hidden_states
@@ -789,6 +789,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
latents = latents.to(self.vae.dtype)
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
@@ -430,6 +430,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
video,
frame_index,
strength,
denoise_strength,
height,
width,
callback_on_step_end_tensor_inputs=None,
@@ -497,6 +498,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
elif isinstance(video, list) and isinstance(strength, list) and len(video) != len(strength):
raise ValueError("If `conditions` is not provided, `video` and `strength` must be of the same length.")
if denoise_strength < 0 or denoise_strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {denoise_strength}")
@staticmethod
def _prepare_video_ids(
batch_size: int,
@@ -649,6 +653,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
width: int = 704,
num_frames: int = 161,
num_prefix_latent_frames: int = 2,
sigma: Optional[torch.Tensor] = None,
latents: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
@@ -658,7 +664,18 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
latent_width = width // self.vae_spatial_compression_ratio
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if latents is not None and sigma is not None:
if latents.shape != shape:
raise ValueError(
f"Latents shape {latents.shape} does not match expected shape {shape}. Please check the input."
)
latents = latents.to(device=device, dtype=dtype)
sigma = sigma.to(device=device, dtype=dtype)
latents = sigma * noise + (1 - sigma) * latents
else:
latents = noise
if len(conditions) > 0:
condition_latent_frames_mask = torch.zeros(
@@ -766,6 +783,13 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
return latents, conditioning_mask, video_ids, extra_conditioning_num_latents
def get_timesteps(self, sigmas, timesteps, num_inference_steps, strength):
num_steps = min(int(num_inference_steps * strength), num_inference_steps)
start_index = max(num_inference_steps - num_steps, 0)
sigmas = sigmas[start_index:]
timesteps = timesteps[start_index:]
return sigmas, timesteps, num_inference_steps - start_index
@property
def guidance_scale(self):
return self._guidance_scale
@@ -799,6 +823,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
video: List[PipelineImageInput] = None,
frame_index: Union[int, List[int]] = 0,
strength: Union[float, List[float]] = 1.0,
denoise_strength: float = 1.0,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 512,
@@ -842,6 +867,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
generation. If not provided, one has to pass `conditions`.
strength (`float` or `List[float]`, *optional*):
The strength or strengths of the conditioning effect. If not provided, one has to pass `conditions`.
denoise_strength (`float`, defaults to `1.0`):
The strength of the noise added to the latents for editing. Higher strength leads to more noise added
to the latents, therefore leading to more differences between original video and generated video. This
is useful for video-to-video editing.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
@@ -918,8 +947,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if latents is not None:
raise ValueError("Passing latents is not yet supported.")
# 1. Check inputs. Raise error if not correct
self.check_inputs(
@@ -929,6 +956,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
video=video,
frame_index=frame_index,
strength=strength,
denoise_strength=denoise_strength,
height=height,
width=width,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
@@ -977,8 +1005,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
strength = [strength] * num_conditions
device = self._execution_device
vae_dtype = self.vae.dtype
# 3. Prepare text embeddings
# 3. Prepare text embeddings & conditioning image/video
(
prompt_embeds,
prompt_attention_mask,
@@ -1000,8 +1029,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
vae_dtype = self.vae.dtype
conditioning_tensors = []
is_conditioning_image_or_video = image is not None or video is not None
if is_conditioning_image_or_video:
@@ -1032,7 +1059,25 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
)
conditioning_tensors.append(condition_tensor)
# 4. Prepare latent variables
# 4. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
sigmas = linear_quadratic_schedule(num_inference_steps)
timesteps = sigmas * 1000
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
latent_sigma = None
if denoise_strength < 1:
sigmas, timesteps, num_inference_steps = self.get_timesteps(
sigmas, timesteps, num_inference_steps, denoise_strength
)
latent_sigma = sigmas[:1].repeat(batch_size * num_videos_per_prompt)
self._num_timesteps = len(timesteps)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents, conditioning_mask, video_coords, extra_conditioning_num_latents = self.prepare_latents(
conditioning_tensors,
@@ -1043,6 +1088,8 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
height=height,
width=width,
num_frames=num_frames,
sigma=latent_sigma,
latents=latents,
generator=generator,
device=device,
dtype=torch.float32,
@@ -1056,21 +1103,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if self.do_classifier_free_guidance:
video_coords = torch.cat([video_coords, video_coords], dim=0)
# 5. Prepare timesteps
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
sigmas = linear_quadratic_schedule(num_inference_steps)
timesteps = sigmas * 1000
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps=timesteps,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -1168,7 +1200,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = torch.randn(latents.shape, generator=generator, device=device, dtype=latents.dtype)
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
@@ -0,0 +1,241 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLLTXVideo
from ...utils import get_logger
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_output import LTXPipelineOutput
logger = get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LTXLatentUpsamplePipeline(DiffusionPipeline):
def __init__(
self,
vae: AutoencoderKLLTXVideo,
latent_upsampler: LTXLatentUpsamplerModel,
) -> None:
super().__init__()
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
def prepare_latents(
self,
video: Optional[torch.Tensor] = None,
batch_size: int = 1,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
video = video.to(device=device, dtype=self.vae.dtype)
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
return init_latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._denormalize_latents
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def check_inputs(self, video, height, width, latents):
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` can be provided.")
if video is None and latents is None:
raise ValueError("One of `video` or `latents` has to be provided.")
@torch.no_grad()
def __call__(
self,
video: Optional[List[PipelineImageInput]] = None,
height: int = 512,
width: int = 704,
latents: Optional[torch.Tensor] = None,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
self.check_inputs(
video=video,
height=height,
width=width,
latents=latents,
)
if video is not None:
# Batched video input is not yet tested/supported. TODO: take a look later
batch_size = 1
else:
batch_size = latents.shape[0]
device = self._execution_device
if video is not None:
num_frames = len(video)
if num_frames % self.vae_temporal_compression_ratio != 1:
num_frames = (
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
)
video = video[:num_frames]
logger.warning(
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
)
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=torch.float32)
latents = self.prepare_latents(
video=video,
batch_size=batch_size,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.latent_upsampler.dtype)
latents = self.latent_upsampler(latents)
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
else:
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)
@@ -248,9 +248,8 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
pretrained pipeline hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
using [`~FlaxDiffusionPipeline.save_pretrained`].
dtype (`str` or `jnp.dtype`, *optional*):
Override the default `jnp.dtype` and load the model under this dtype. If `"auto"`, the dtype is
automatically derived from the model's weights.
dtype (`jnp.dtype`, *optional*):
Override the default `jnp.dtype` and load the model under this dtype.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@@ -675,8 +675,10 @@ def load_sub_model(
use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
quantization_config: Optional[Any] = None,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
from ..quantizers import PipelineQuantizationConfig
# retrieve class candidates
@@ -769,6 +771,17 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)
and issubclass(class_obj, torch.nn.Module)
):
model_quant_config = quantization_config._resolve_quant_config(
is_diffusers=is_diffusers_model, module_name=name
)
if model_quant_config is not None:
loading_kwargs["quantization_config"] = model_quant_config
# check if the module is in a subdirectory
if dduf_entries:
loading_kwargs["dduf_entries"] = dduf_entries
+12 -6
View File
@@ -47,6 +47,7 @@ from ..configuration_utils import ConfigMixin
from ..models import AutoencoderKL
from ..models.attention_processor import FusedAttnProcessor2_0
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
from ..quantizers import PipelineQuantizationConfig
from ..quantizers.bitsandbytes.utils import _check_bnb_status
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from ..utils import (
@@ -572,12 +573,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
saved using
[`~DiffusionPipeline.save_pretrained`].
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
`dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
torch_dtype (`torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. To load submodels with
different dtype pass a `dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`).
Set the default dtype for unspecified components with `default` (for example `{'transformer':
torch.bfloat16, 'default': torch.float16}`). If a component is not specified and no default is set,
`torch.float32` is used.
custom_pipeline (`str`, *optional*):
<Tip warning={true}>
@@ -725,6 +726,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
@@ -741,6 +743,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
" install accelerate\n```\n."
)
if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig):
raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.")
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
@@ -1001,6 +1006,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_safetensors=use_safetensors,
dduf_entries=dduf_entries,
provider_options=provider_options,
quantization_config=quantization_config,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -30,18 +30,11 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["clip_image_project_model"] = ["CLIPImageProjection"]
_import_structure["pipeline_cycle_diffusion"] = ["CycleDiffusionPipeline"]
_import_structure["pipeline_stable_diffusion"] = ["StableDiffusionPipeline"]
_import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
_import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
_import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
_import_structure["pipeline_stable_diffusion_img2img"] = ["StableDiffusionImg2ImgPipeline"]
_import_structure["pipeline_stable_diffusion_inpaint"] = ["StableDiffusionInpaintPipeline"]
_import_structure["pipeline_stable_diffusion_inpaint_legacy"] = ["StableDiffusionInpaintPipelineLegacy"]
_import_structure["pipeline_stable_diffusion_instruct_pix2pix"] = ["StableDiffusionInstructPix2PixPipeline"]
_import_structure["pipeline_stable_diffusion_latent_upscale"] = ["StableDiffusionLatentUpscalePipeline"]
_import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
_import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
@@ -0,0 +1,52 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_visualcloze_combined"] = ["VisualClozePipeline"]
_import_structure["pipeline_visualcloze_generation"] = ["VisualClozeGenerationPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_visualcloze_combined import VisualClozePipeline
from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,444 @@
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ..flux.pipeline_flux_fill import FluxFillPipeline as VisualClozeUpsamplingPipeline
from ..flux.pipeline_output import FluxPipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .pipeline_visualcloze_generation import VisualClozeGenerationPipeline
if is_torch_xla_available():
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import VisualClozePipeline
>>> from diffusers.utils import load_image
>>> image_paths = [
... # in-context examples
... [
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
... ),
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
... ),
... ],
... # query with the target image
... [
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
... ),
... None, # No image needed for the target image
... ],
... ]
>>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
>>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
>>> pipe = VisualClozePipeline.from_pretrained(
... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(
... task_prompt=task_prompt,
... content_prompt=content_prompt,
... image=image_paths,
... upsampling_width=1344,
... upsampling_height=768,
... upsampling_strength=0.4,
... guidance_scale=30,
... num_inference_steps=30,
... max_sequence_length=512,
... generator=torch.Generator("cpu").manual_seed(0),
... ).images[0][0]
>>> image.save("visualcloze.png")
```
"""
class VisualClozePipeline(
DiffusionPipeline,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
r"""
The VisualCloze pipeline for image generation with visual context. Reference:
https://github.com/lzyhha/VisualCloze/tree/main. This pipeline is designed to generate images based on visual
in-context examples.
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
resolution (`int`, *optional*, defaults to 384):
The resolution of each image when concatenating images from the query and in-context examples.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
resolution: int = 384,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
self.generation_pipe = VisualClozeGenerationPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
resolution=resolution,
)
self.upsampling_pipe = VisualClozeUpsamplingPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
def check_inputs(
self,
image,
task_prompt,
content_prompt,
upsampling_height,
upsampling_width,
strength,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if upsampling_height is not None and upsampling_height % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`upsampling_height`has to be divisible by {self.vae_scale_factor * 2} but are {upsampling_height}. Dimensions will be resized accordingly"
)
if upsampling_width is not None and upsampling_width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`upsampling_width` have to be divisible by {self.vae_scale_factor * 2} but are {upsampling_width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# Validate prompt inputs
if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
if task_prompt is None and content_prompt is None and prompt_embeds is None:
raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")
# Validate prompt types and consistency
if task_prompt is None:
raise ValueError("`task_prompt` is missing.")
if task_prompt is not None and not isinstance(task_prompt, (str, list)):
raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}")
if content_prompt is not None and not isinstance(content_prompt, (str, list)):
raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}")
if isinstance(task_prompt, list) or isinstance(content_prompt, list):
if not isinstance(task_prompt, list) or not isinstance(content_prompt, list):
raise ValueError(
f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, "
f"got {type(task_prompt)} and {type(content_prompt)}"
)
if len(content_prompt) != len(task_prompt):
raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.")
for sample in image:
if not isinstance(sample, list) or not isinstance(sample[0], list):
raise ValueError("Each sample in the batch must have a 2D list of images.")
if len({len(row) for row in sample}) != 1:
raise ValueError("Each in-context example and query should contain the same number of images.")
if not any(img is None for img in sample[-1]):
raise ValueError("There are no targets in the query, which should be represented as None.")
for row in sample[:-1]:
if any(img is None for img in row):
raise ValueError("Images are missing in in-context examples.")
# Validate embeddings
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
# Validate sequence length
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
task_prompt: Union[str, List[str]] = None,
content_prompt: Union[str, List[str]] = None,
image: Optional[torch.FloatTensor] = None,
upsampling_height: Optional[int] = None,
upsampling_width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
upsampling_strength: float = 1.0,
):
r"""
Function invoked when calling the VisualCloze pipeline for generation.
Args:
task_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the task intention.
content_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the content or caption of the target image to be generated.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
upsampling_height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By
default, the image is upsampled by a factor of three, and the base resolution is determined by the
resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is
specified, the other will be automatically set based on the aspect ratio.
upsampling_width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image (i.e., output image) after upsampling via SDEdit. By
default, the image is upsampled by a factor of three, and the base resolution is determined by the
resolution parameter of the pipeline. When only one of `upsampling_height` or `upsampling_width` is
specified, the other will be automatically set based on the aspect ratio.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 30.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
upsampling_strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image` when upsampling the results. Must be between 0 and
1. The generated image is used as a starting point and more noise is added the higher the
`upsampling_strength`. The number of denoising steps depends on the amount of noise initially added.
When `upsampling_strength` is 1, added noise is maximum and the denoising process runs for the full
number of iterations specified in `num_inference_steps`. A value of 0 skips the upsampling step and
output the results at the resolution of `self.resolution`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
generation_output = self.generation_pipe(
task_prompt=task_prompt,
content_prompt=content_prompt,
image=image,
num_inference_steps=num_inference_steps,
sigmas=sigmas,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
joint_attention_kwargs=joint_attention_kwargs,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
output_type=output_type if upsampling_strength == 0 else "pil",
)
if upsampling_strength == 0:
if not return_dict:
return (generation_output,)
return FluxPipelineOutput(images=generation_output)
# Upsampling the generated images
# 1. Prepare the input images and prompts
if not isinstance(content_prompt, (list)):
content_prompt = [content_prompt]
n_target_per_sample = []
upsampling_image = []
upsampling_mask = []
upsampling_prompt = []
upsampling_generator = generator if isinstance(generator, (torch.Generator,)) else []
for i in range(len(generation_output.images)):
n_target_per_sample.append(len(generation_output.images[i]))
for image in generation_output.images[i]:
upsampling_image.append(image)
upsampling_mask.append(Image.new("RGB", image.size, (255, 255, 255)))
upsampling_prompt.append(
content_prompt[i % len(content_prompt)] if content_prompt[i % len(content_prompt)] else ""
)
if not isinstance(generator, (torch.Generator,)):
upsampling_generator.append(generator[i % len(content_prompt)])
# 2. Apply the denosing loop
upsampling_output = self.upsampling_pipe(
prompt=upsampling_prompt,
image=upsampling_image,
mask_image=upsampling_mask,
height=upsampling_height,
width=upsampling_width,
strength=upsampling_strength,
num_inference_steps=num_inference_steps,
sigmas=sigmas,
guidance_scale=guidance_scale,
generator=upsampling_generator,
output_type=output_type,
joint_attention_kwargs=joint_attention_kwargs,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
image = upsampling_output.images
output = []
if output_type == "pil":
# Each sample in the batch may have multiple output images. When returning as PIL images,
# these images cannot be concatenated. Therefore, for each sample,
# a list is used to represent all the output images.
output = []
start = 0
for n in n_target_per_sample:
output.append(image[start : start + n])
start += n
else:
output = image
if not return_dict:
return (output,)
return FluxPipelineOutput(images=output)
@@ -0,0 +1,952 @@
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..flux.pipeline_flux_fill import calculate_shift, retrieve_latents, retrieve_timesteps
from ..flux.pipeline_output import FluxPipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .visualcloze_utils import VisualClozeProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import VisualClozeGenerationPipeline, FluxFillPipeline as VisualClozeUpsamplingPipeline
>>> from diffusers.utils import load_image
>>> from PIL import Image
>>> image_paths = [
... # in-context examples
... [
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_mask.jpg"
... ),
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_incontext-example-1_image.jpg"
... ),
... ],
... # query with the target image
... [
... load_image(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/visualcloze/visualcloze_mask2image_query_mask.jpg"
... ),
... None, # No image needed for the target image
... ],
... ]
>>> task_prompt = "In each row, a logical task is demonstrated to achieve [IMAGE2] an aesthetically pleasing photograph based on [IMAGE1] sam 2-generated masks with rich color coding."
>>> content_prompt = "Majestic photo of a golden eagle perched on a rocky outcrop in a mountainous landscape. The eagle is positioned in the right foreground, facing left, with its sharp beak and keen eyes prominently visible. Its plumage is a mix of dark brown and golden hues, with intricate feather details. The background features a soft-focus view of snow-capped mountains under a cloudy sky, creating a serene and grandiose atmosphere. The foreground includes rugged rocks and patches of green moss. Photorealistic, medium depth of field, soft natural lighting, cool color palette, high contrast, sharp focus on the eagle, blurred background, tranquil, majestic, wildlife photography."
>>> pipe = VisualClozeGenerationPipeline.from_pretrained(
... "VisualCloze/VisualClozePipeline-384", resolution=384, torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(
... task_prompt=task_prompt,
... content_prompt=content_prompt,
... image=image_paths,
... guidance_scale=30,
... num_inference_steps=30,
... max_sequence_length=512,
... generator=torch.Generator("cpu").manual_seed(0),
... ).images[0][0]
>>> # optional, upsampling the generated image
>>> pipe_upsample = VisualClozeUpsamplingPipeline.from_pipe(pipe)
>>> pipe_upsample.to("cuda")
>>> mask_image = Image.new("RGB", image.size, (255, 255, 255))
>>> image = pipe_upsample(
... image=image,
... mask_image=mask_image,
... prompt=content_prompt,
... width=1344,
... height=768,
... strength=0.4,
... guidance_scale=30,
... num_inference_steps=30,
... max_sequence_length=512,
... generator=torch.Generator("cpu").manual_seed(0),
... ).images[0]
>>> image.save("visualcloze.png")
```
"""
class VisualClozeGenerationPipeline(
DiffusionPipeline,
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
r"""
The VisualCloze pipeline for image generation with visual context. Reference:
https://github.com/lzyhha/VisualCloze/tree/main This pipeline is designed to generate images based on visual
in-context examples.
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
resolution (`int`, *optional*, defaults to 384):
The resolution of each image when concatenating images from the query and in-context examples.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
resolution: int = 384,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
self.resolution = resolution
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VisualClozeProcessor(
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels, resolution=resolution
)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 128
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
# Modified from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
def encode_prompt(
self,
layout_prompt: Union[str, List[str]],
task_prompt: Union[str, List[str]],
content_prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
layout_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the number of in-context examples and the number of images involved in
the task.
task_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the task intention.
content_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the content or caption of the target image to be generated.
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
if isinstance(layout_prompt, str):
layout_prompt = [layout_prompt]
task_prompt = [task_prompt]
content_prompt = [content_prompt]
def _preprocess(prompt, content=False):
if prompt is not None:
return f"The last image of the last row depicts: {prompt}" if content else prompt
else:
return ""
prompt = [
f"{_preprocess(layout_prompt[i])} {_preprocess(task_prompt[i])} {_preprocess(content_prompt[i], content=True)}".strip()
for i in range(len(layout_prompt))
]
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
return image_latents
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
def check_inputs(
self,
image,
task_prompt,
content_prompt,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
# Validate prompt inputs
if (task_prompt is not None or content_prompt is not None) and prompt_embeds is not None:
raise ValueError("Cannot provide both text `task_prompt` + `content_prompt` and `prompt_embeds`. ")
if task_prompt is None and content_prompt is None and prompt_embeds is None:
raise ValueError("Must provide either `task_prompt` + `content_prompt` or pre-computed `prompt_embeds`. ")
# Validate prompt types and consistency
if task_prompt is None:
raise ValueError("`task_prompt` is missing.")
if task_prompt is not None and not isinstance(task_prompt, (str, list)):
raise ValueError(f"`task_prompt` must be str or list, got {type(task_prompt)}")
if content_prompt is not None and not isinstance(content_prompt, (str, list)):
raise ValueError(f"`content_prompt` must be str or list, got {type(content_prompt)}")
if isinstance(task_prompt, list) or isinstance(content_prompt, list):
if not isinstance(task_prompt, list) or not isinstance(content_prompt, list):
raise ValueError(
f"`task_prompt` and `content_prompt` must both be lists, or both be of type str or None, "
f"got {type(task_prompt)} and {type(content_prompt)}"
)
if len(content_prompt) != len(task_prompt):
raise ValueError("`task_prompt` and `content_prompt` must have the same length whe they are lists.")
for sample in image:
if not isinstance(sample, list) or not isinstance(sample[0], list):
raise ValueError("Each sample in the batch must have a 2D list of images.")
if len({len(row) for row in sample}) != 1:
raise ValueError("Each in-context example and query should contain the same number of images.")
if not any(img is None for img in sample[-1]):
raise ValueError("There are no targets in the query, which should be represented as None.")
for row in sample[:-1]:
if any(img is None for img in row):
raise ValueError("Images are missing in in-context examples.")
# Validate embeddings
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
# Validate sequence length
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"max_sequence_length cannot exceed 512, got {max_sequence_length}")
@staticmethod
def _prepare_latent_image_ids(image, vae_scale_factor, device, dtype):
latent_image_ids = []
for idx, img in enumerate(image, start=1):
img = img.squeeze(0)
channels, height, width = img.shape
num_patches_h = height // vae_scale_factor // 2
num_patches_w = width // vae_scale_factor // 2
patch_ids = torch.zeros(num_patches_h, num_patches_w, 3, device=device, dtype=dtype)
patch_ids[..., 0] = idx
patch_ids[..., 1] = torch.arange(num_patches_h, device=device, dtype=dtype)[:, None]
patch_ids[..., 2] = torch.arange(num_patches_w, device=device, dtype=dtype)[None, :]
patch_ids = patch_ids.reshape(-1, 3)
latent_image_ids.append(patch_ids)
return torch.cat(latent_image_ids, dim=0)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
def _unpack_latents(latents, sizes, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
start = 0
unpacked_latents = []
for i in range(len(sizes)):
cur_size = sizes[i]
height = cur_size[0][0] // vae_scale_factor
width = sum([size[1] for size in cur_size]) // vae_scale_factor
end = start + (height * width) // 4
cur_latents = latents[:, start:end]
cur_latents = cur_latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
cur_latents = cur_latents.permute(0, 3, 1, 4, 2, 5)
cur_latents = cur_latents.reshape(batch_size, channels // (2 * 2), height, width)
unpacked_latents.append(cur_latents)
start = end
return unpacked_latents
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def _prepare_latents(self, image, mask, gen, vae_scale_factor, device, dtype):
"""Helper function to prepare latents for a single batch."""
# Concatenate images and masks along width dimension
image = [torch.cat(img, dim=3).to(device=device, dtype=dtype) for img in image]
mask = [torch.cat(m, dim=3).to(device=device, dtype=dtype) for m in mask]
# Generate latent image IDs
latent_image_ids = self._prepare_latent_image_ids(image, vae_scale_factor, device, dtype)
# For initial encoding, use actual images
image_latent = [self._encode_vae_image(img, gen) for img in image]
masked_image_latent = [img.clone() for img in image_latent]
for i in range(len(image_latent)):
# Rearrange latents and masks for patch processing
num_channels_latents, height, width = image_latent[i].shape[1:]
image_latent[i] = self._pack_latents(image_latent[i], 1, num_channels_latents, height, width)
masked_image_latent[i] = self._pack_latents(masked_image_latent[i], 1, num_channels_latents, height, width)
# Rearrange masks for patch processing
num_channels_latents, height, width = mask[i].shape[1:]
mask[i] = mask[i].view(
1,
num_channels_latents,
height // vae_scale_factor,
vae_scale_factor,
width // vae_scale_factor,
vae_scale_factor,
)
mask[i] = mask[i].permute(0, 1, 3, 5, 2, 4)
mask[i] = mask[i].reshape(
1,
num_channels_latents * (vae_scale_factor**2),
height // vae_scale_factor,
width // vae_scale_factor,
)
mask[i] = self._pack_latents(
mask[i],
1,
num_channels_latents * (vae_scale_factor**2),
height // vae_scale_factor,
width // vae_scale_factor,
)
# Concatenate along batch dimension
image_latent = torch.cat(image_latent, dim=1)
masked_image_latent = torch.cat(masked_image_latent, dim=1)
mask = torch.cat(mask, dim=1)
return image_latent, masked_image_latent, mask, latent_image_ids
def prepare_latents(
self,
input_image,
input_mask,
timestep,
batch_size,
dtype,
device,
generator,
vae_scale_factor,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
# Process each batch
masked_image_latents = []
image_latents = []
masks = []
latent_image_ids = []
for i in range(len(input_image)):
_image_latent, _masked_image_latent, _mask, _latent_image_ids = self._prepare_latents(
input_image[i],
input_mask[i],
generator if isinstance(generator, torch.Generator) else generator[i],
vae_scale_factor,
device,
dtype,
)
masked_image_latents.append(_masked_image_latent)
image_latents.append(_image_latent)
masks.append(_mask)
latent_image_ids.append(_latent_image_ids)
# Concatenate all batches
masked_image_latents = torch.cat(masked_image_latents, dim=0)
image_latents = torch.cat(image_latents, dim=0)
masks = torch.cat(masks, dim=0)
# Handle batch size expansion
if batch_size > masked_image_latents.shape[0]:
if batch_size % masked_image_latents.shape[0] == 0:
# Expand batches by repeating
additional_image_per_prompt = batch_size // masked_image_latents.shape[0]
masked_image_latents = torch.cat([masked_image_latents] * additional_image_per_prompt, dim=0)
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
masks = torch.cat([masks] * additional_image_per_prompt, dim=0)
else:
raise ValueError(
f"Cannot expand batch size from {masked_image_latents.shape[0]} to {batch_size}. "
"Batch sizes must be multiples of each other."
)
# Add noise to latents
noises = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noises).to(dtype=dtype)
# Combine masked latents with masks
masked_image_latents = torch.cat((masked_image_latents, masks), dim=-1).to(dtype=dtype)
return latents, masked_image_latents, latent_image_ids[0]
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
task_prompt: Union[str, List[str]] = None,
content_prompt: Union[str, List[str]] = None,
image: Optional[torch.FloatTensor] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the VisualCloze pipeline for generation.
Args:
task_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the task intention.
content_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to define the content or caption of the target image to be generated.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 30.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
# 1. Check inputs. Raise error if not correct
self.check_inputs(
image,
task_prompt,
content_prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
processor_output = self.image_processor.preprocess(
task_prompt, content_prompt, image, vae_scale_factor=self.vae_scale_factor
)
# 2. Define call parameters
if processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], str):
batch_size = 1
elif processor_output["task_prompt"] is not None and isinstance(processor_output["task_prompt"], list):
batch_size = len(processor_output["task_prompt"])
device = self._execution_device
# 3. Prepare prompt embeddings
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
layout_prompt=processor_output["layout_prompt"],
task_prompt=processor_output["task_prompt"],
content_prompt=processor_output["content_prompt"],
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare timesteps
# Calculate sequence length and shift factor
image_seq_len = sum(
(size[0] // self.vae_scale_factor // 2) * (size[1] // self.vae_scale_factor // 2)
for sample in processor_output["image_size"][0]
for size in sample
)
# Calculate noise schedule parameters
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
# Get timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device)
# 5. Prepare latent variables
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents, masked_image_latents, latent_image_ids = self.prepare_latents(
processor_output["init_image"],
processor_output["mask"],
latent_timestep,
batch_size * num_images_per_prompt,
prompt_embeds.dtype,
device,
generator,
vae_scale_factor=self.vae_scale_factor,
)
# Calculate warmup steps
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# Prepare guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
latent_model_input = torch.cat((latents, masked_image_latents), dim=2)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# Compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# Some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# Call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# XLA optimization
if XLA_AVAILABLE:
xm.mark_step()
# 7. Post-process the image
# Crop the target image
# Since the generated image is a concatenation of the conditional and target regions,
# we need to extract only the target regions based on their positions
image = []
if output_type == "latent":
image = latents
else:
for b in range(len(latents)):
cur_image_size = processor_output["image_size"][b % batch_size]
cur_target_position = processor_output["target_position"][b % batch_size]
cur_latent = self._unpack_latents(latents[b].unsqueeze(0), cur_image_size, self.vae_scale_factor)[-1]
cur_latent = (cur_latent / self.vae.config.scaling_factor) + self.vae.config.shift_factor
cur_image = self.vae.decode(cur_latent, return_dict=False)[0]
cur_image = self.image_processor.postprocess(cur_image, output_type=output_type)[0]
start = 0
cropped = []
for i, size in enumerate(cur_image_size[-1]):
if cur_target_position[i]:
if output_type == "pil":
cropped.append(cur_image.crop((start, 0, start + size[1], size[0])))
else:
cropped.append(cur_image[0 : size[0], start : start + size[1]])
start += size[1]
image.append(cropped)
if output_type != "pil":
image = np.concatenate([arr[None] for sub_image in image for arr in sub_image], axis=0)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
@@ -0,0 +1,251 @@
# Copyright 2025 VisualCloze team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
from PIL import Image
from ...image_processor import VaeImageProcessor
class VisualClozeProcessor(VaeImageProcessor):
"""
Image processor for the VisualCloze pipeline.
This processor handles the preprocessing of images for visual cloze tasks, including resizing, normalization, and
mask generation.
Args:
resolution (int, optional):
Target resolution for processing images. Each image will be resized to this resolution before being
concatenated to avoid the out-of-memory error. Defaults to 384.
*args: Additional arguments passed to [~image_processor.VaeImageProcessor]
**kwargs: Additional keyword arguments passed to [~image_processor.VaeImageProcessor]
"""
def __init__(self, *args, resolution: int = 384, **kwargs):
super().__init__(*args, **kwargs)
self.resolution = resolution
def preprocess_image(
self, input_images: List[List[Optional[Image.Image]]], vae_scale_factor: int
) -> Tuple[List[List[torch.Tensor]], List[List[List[int]]], List[int]]:
"""
Preprocesses input images for the VisualCloze pipeline.
This function handles the preprocessing of input images by:
1. Resizing and cropping images to maintain consistent dimensions
2. Converting images to the Tensor format for the VAE
3. Normalizing pixel values
4. Tracking image sizes and positions of target images
Args:
input_images (List[List[Optional[Image.Image]]]):
A nested list of PIL Images where:
- Outer list represents different samples, including in-context examples and the query
- Inner list contains images for the task
- In the last row, condition images are provided and the target images are placed as None
vae_scale_factor (int):
The scale factor used by the VAE for resizing images
Returns:
Tuple containing:
- List[List[torch.Tensor]]: Preprocessed images in tensor format
- List[List[List[int]]]: Dimensions of each processed image [height, width]
- List[int]: Target positions indicating which images are to be generated
"""
n_samples, n_task_images = len(input_images), len(input_images[0])
divisible = 2 * vae_scale_factor
processed_images: List[List[Image.Image]] = [[] for _ in range(n_samples)]
resize_size: List[Optional[Tuple[int, int]]] = [None for _ in range(n_samples)]
target_position: List[int] = []
# Process each sample
for i in range(n_samples):
# Determine size from first non-None image
for j in range(n_task_images):
if input_images[i][j] is not None:
aspect_ratio = input_images[i][j].width / input_images[i][j].height
target_area = self.resolution * self.resolution
new_h = int((target_area / aspect_ratio) ** 0.5)
new_w = int(new_h * aspect_ratio)
new_w = max(new_w // divisible, 1) * divisible
new_h = max(new_h // divisible, 1) * divisible
resize_size[i] = (new_w, new_h)
break
# Process all images in the sample
for j in range(n_task_images):
if input_images[i][j] is not None:
target = self._resize_and_crop(input_images[i][j], resize_size[i][0], resize_size[i][1])
processed_images[i].append(target)
if i == n_samples - 1:
target_position.append(0)
else:
blank = Image.new("RGB", resize_size[i] or (self.resolution, self.resolution), (0, 0, 0))
processed_images[i].append(blank)
if i == n_samples - 1:
target_position.append(1)
# Ensure consistent width for multiple target images when there are multiple target images
if len(target_position) > 1 and sum(target_position) > 1:
new_w = resize_size[n_samples - 1][0] or 384
for i in range(len(processed_images)):
for j in range(len(processed_images[i])):
if processed_images[i][j] is not None:
new_h = int(processed_images[i][j].height * (new_w / processed_images[i][j].width))
new_w = int(new_w / 16) * 16
new_h = int(new_h / 16) * 16
processed_images[i][j] = self.height(processed_images[i][j], new_h, new_w)
# Convert to tensors and normalize
image_sizes = []
for i in range(len(processed_images)):
image_sizes.append([[img.height, img.width] for img in processed_images[i]])
for j, image in enumerate(processed_images[i]):
image = self.pil_to_numpy(image)
image = self.numpy_to_pt(image)
image = self.normalize(image)
processed_images[i][j] = image
return processed_images, image_sizes, target_position
def preprocess_mask(
self, input_images: List[List[Image.Image]], target_position: List[int]
) -> List[List[torch.Tensor]]:
"""
Generate masks for the VisualCloze pipeline.
Args:
input_images (List[List[Image.Image]]):
Processed images from preprocess_image
target_position (List[int]):
Binary list marking the positions of target images (1 for target, 0 for condition)
Returns:
List[List[torch.Tensor]]:
A nested list of mask tensors (1 for target positions, 0 for condition images)
"""
mask = []
for i, row in enumerate(input_images):
if i == len(input_images) - 1: # Query row
row_masks = [
torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=m) for m in target_position
]
else: # In-context examples
row_masks = [
torch.full((1, 1, row[0].shape[2], row[0].shape[3]), fill_value=0) for _ in target_position
]
mask.append(row_masks)
return mask
def preprocess_image_upsampling(
self,
input_images: List[List[Image.Image]],
height: int,
width: int,
) -> Tuple[List[List[Image.Image]], List[List[List[int]]]]:
"""Process images for the upsampling stage in the VisualCloze pipeline.
Args:
input_images: Input image to process
height: Target height
width: Target width
Returns:
Tuple of processed image and its size
"""
image = self.resize(input_images[0][0], height, width)
image = self.pil_to_numpy(image) # to np
image = self.numpy_to_pt(image) # to pt
image = self.normalize(image)
input_images[0][0] = image
image_sizes = [[[height, width]]]
return input_images, image_sizes
def preprocess_mask_upsampling(self, input_images: List[List[Image.Image]]) -> List[List[torch.Tensor]]:
return [[torch.ones((1, 1, input_images[0][0].shape[2], input_images[0][0].shape[3]))]]
def get_layout_prompt(self, size: Tuple[int, int]) -> str:
layout_instruction = (
f"A grid layout with {size[0]} rows and {size[1]} columns, displaying {size[0] * size[1]} images arranged side by side.",
)
return layout_instruction
def preprocess(
self,
task_prompt: Union[str, List[str]],
content_prompt: Union[str, List[str]],
input_images: Optional[List[List[List[Optional[str]]]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
upsampling: bool = False,
vae_scale_factor: int = 16,
) -> Dict:
"""Process visual cloze inputs.
Args:
task_prompt: Task description(s)
content_prompt: Content description(s)
input_images: List of images or None for the target images
height: Optional target height for upsampling stage
width: Optional target width for upsampling stage
upsampling: Whether this is in the upsampling processing stage
Returns:
Dictionary containing processed images, masks, prompts and metadata
"""
if isinstance(task_prompt, str):
task_prompt = [task_prompt]
content_prompt = [content_prompt]
input_images = [input_images]
output = {
"init_image": [],
"mask": [],
"task_prompt": task_prompt if not upsampling else [None for _ in range(len(task_prompt))],
"content_prompt": content_prompt,
"layout_prompt": [],
"target_position": [],
"image_size": [],
}
for i in range(len(task_prompt)):
if upsampling:
layout_prompt = None
else:
layout_prompt = self.get_layout_prompt((len(input_images[i]), len(input_images[i][0])))
if upsampling:
cur_processed_images, cur_image_size = self.preprocess_image_upsampling(
input_images[i], height=height, width=width
)
cur_mask = self.preprocess_mask_upsampling(cur_processed_images)
else:
cur_processed_images, cur_image_size, cur_target_position = self.preprocess_image(
input_images[i], vae_scale_factor=vae_scale_factor
)
cur_mask = self.preprocess_mask(cur_processed_images, cur_target_position)
output["target_position"].append(cur_target_position)
output["image_size"].append(cur_image_size)
output["init_image"].append(cur_processed_images)
output["mask"].append(cur_mask)
output["layout_prompt"].append(layout_prompt)
return output
+178
View File
@@ -12,5 +12,183 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Dict, List, Optional, Union
from ..utils import is_transformers_available, logging
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
try:
from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
except ImportError:
class TransformersQuantConfigMixin:
pass
logger = logging.get_logger(__name__)
class PipelineQuantizationConfig:
"""
Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
Args:
quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
is available to both `diffusers` and `transformers`.
quant_kwargs (`dict`): Params to initialize the quantization backend class.
components_to_quantize (`list`): Components of a pipeline to be quantized.
quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
and `components_to_quantize`.
"""
def __init__(
self,
quant_backend: str = None,
quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
components_to_quantize: Optional[List[str]] = None,
quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
):
self.quant_backend = quant_backend
# Initialize kwargs to be {} to set to the defaults.
self.quant_kwargs = quant_kwargs or {}
self.components_to_quantize = components_to_quantize
self.quant_mapping = quant_mapping
self.post_init()
def post_init(self):
quant_mapping = self.quant_mapping
self.is_granular = True if quant_mapping is not None else False
self._validate_init_args()
def _validate_init_args(self):
if self.quant_backend and self.quant_mapping:
raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
if not self.quant_mapping and not self.quant_backend:
raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
if not self.quant_kwargs and not self.quant_mapping:
raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
if self.quant_backend is not None:
self._validate_init_kwargs_in_backends()
if self.quant_mapping is not None:
self._validate_quant_mapping_args()
def _validate_init_kwargs_in_backends(self):
quant_backend = self.quant_backend
self._check_backend_availability(quant_backend)
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
if quant_config_mapping_transformers is not None:
init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
else:
init_kwargs_transformers = None
init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
if init_kwargs_transformers != init_kwargs_diffusers:
raise ValueError(
"The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
"this mapping would look like."
)
def _validate_quant_mapping_args(self):
quant_mapping = self.quant_mapping
transformers_map, diffusers_map = self._get_quant_config_list()
available_transformers = list(transformers_map.values()) if transformers_map else None
available_diffusers = list(diffusers_map.values())
for module_name, config in quant_mapping.items():
if any(isinstance(config, cfg) for cfg in available_diffusers):
continue
if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
continue
if available_transformers:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}; "
f"Available transformers configs: {available_transformers}."
)
else:
raise ValueError(
f"Provided config for module_name={module_name} could not be found. "
f"Available diffusers configs: {available_diffusers}."
)
def _check_backend_availability(self, quant_backend: str):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
available_backends_transformers = (
list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
)
available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
if (
available_backends_transformers and quant_backend not in available_backends_transformers
) or quant_backend not in quant_config_mapping_diffusers:
error_message = f"Provided quant_backend={quant_backend} was not found."
if available_backends_transformers:
error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
raise ValueError(error_message)
def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
quant_mapping = self.quant_mapping
components_to_quantize = self.components_to_quantize
# Granular case
if self.is_granular and module_name in quant_mapping:
logger.debug(f"Initializing quantization config class for {module_name}.")
config = quant_mapping[module_name]
return config
# Global config case
else:
should_quantize = False
# Only quantize the modules requested for.
if components_to_quantize and module_name in components_to_quantize:
should_quantize = True
# No specification for `components_to_quantize` means all modules should be quantized.
elif not self.is_granular and not components_to_quantize:
should_quantize = True
if should_quantize:
logger.debug(f"Initializing quantization config class for {module_name}.")
mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
quant_config_cls = mapping_to_use[self.quant_backend]
quant_kwargs = self.quant_kwargs
return quant_config_cls(**quant_kwargs)
# Fallback: no applicable configuration found.
return None
def _get_quant_config_list(self):
if is_transformers_available():
from transformers.quantizers.auto import (
AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
)
else:
quant_config_mapping_transformers = None
from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
return quant_config_mapping_transformers, quant_config_mapping_diffusers
+15 -12
View File
@@ -408,6 +408,18 @@ class GGUFParameter(torch.nn.Parameter):
def as_tensor(self):
return torch.Tensor._make_subclass(torch.Tensor, self, self.requires_grad)
@staticmethod
def _extract_quant_type(args):
# When converting from original format checkpoints we often use splits, cats etc on tensors
# this method ensures that the returned tensor type from those operations remains GGUFParameter
# so that we preserve quant_type information
for arg in args:
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
return arg[0].quant_type
if isinstance(arg, GGUFParameter):
return arg.quant_type
return None
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
@@ -415,22 +427,13 @@ class GGUFParameter(torch.nn.Parameter):
result = super().__torch_function__(func, types, args, kwargs)
# When converting from original format checkpoints we often use splits, cats etc on tensors
# this method ensures that the returned tensor type from those operations remains GGUFParameter
# so that we preserve quant_type information
quant_type = None
for arg in args:
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
quant_type = arg[0].quant_type
break
if isinstance(arg, GGUFParameter):
quant_type = arg.quant_type
break
if isinstance(result, torch.Tensor):
quant_type = cls._extract_quant_type(args)
return cls(result, quant_type=quant_type)
# Handle tuples and lists
elif isinstance(result, (tuple, list)):
elif type(result) in (list, tuple):
# Preserve the original type (tuple or list)
quant_type = cls._extract_quant_type(args)
wrapped = [cls(x, quant_type=quant_type) if isinstance(x, torch.Tensor) else x for x in result]
return type(result)(wrapped)
else:
@@ -75,7 +75,7 @@ class QuantizationConfigMixin:
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object.
return_unused_kwargs (`bool`,*optional*, defaults to `False`):
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
`PreTrainedModel`.
kwargs (`Dict[str, Any]`):
@@ -144,7 +144,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
@@ -568,5 +568,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
def _get_conditioning_c_in(self, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
return self.config.num_train_timesteps
@@ -176,7 +176,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
@@ -703,5 +703,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
def _get_conditioning_c_in(self, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
return self.config.num_train_timesteps
@@ -103,11 +103,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
sigmas = torch.arange(num_train_timesteps + 1, dtype=sigmas_dtype) / num_train_timesteps
if sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(sigmas)
elif sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(sigmas)
sigmas = sigmas.to(torch.float32)
self.timesteps = self.precondition_noise(sigmas)
@@ -159,7 +161,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
def precondition_inputs(self, sample, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
@@ -230,18 +232,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
sigmas_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
if sigmas is None:
sigmas = torch.linspace(0, 1, self.num_inference_steps)
sigmas = torch.linspace(0, 1, self.num_inference_steps, dtype=sigmas_dtype)
elif isinstance(sigmas, float):
sigmas = torch.tensor(sigmas, dtype=torch.float32)
sigmas = torch.tensor(sigmas, dtype=sigmas_dtype)
else:
sigmas = sigmas
sigmas = sigmas.to(sigmas_dtype)
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(sigmas)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(sigmas)
sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)
if self.config.final_sigmas_type == "sigma_min":
@@ -315,6 +318,7 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
s_noise: float = 1.0,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
pred_original_sample: Optional[torch.Tensor] = None,
) -> Union[EDMEulerSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
@@ -378,7 +382,8 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
if pred_original_sample is None:
pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
@@ -435,5 +440,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma
return noisy_samples
def _get_conditioning_c_in(self, sigma):
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
return self.config.num_train_timesteps
+4
View File
@@ -62,9 +62,11 @@ from .import_utils import (
get_objects_from_module,
is_accelerate_available,
is_accelerate_version,
is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_bs4_available,
is_cosmos_guardrail_available,
is_flax_available,
is_ftfy_available,
is_gguf_available,
@@ -78,6 +80,7 @@ from .import_utils import (
is_k_diffusion_version,
is_librosa_available,
is_matplotlib_available,
is_nltk_available,
is_note_seq_available,
is_onnx_available,
is_opencv_available,
@@ -85,6 +88,7 @@ from .import_utils import (
is_optimum_quanto_version,
is_peft_available,
is_peft_version,
is_pytorch_retinaface_available,
is_safetensors_available,
is_scipy_available,
is_sentencepiece_available,
+45
View File
@@ -160,6 +160,21 @@ class AutoencoderKLCogVideoX(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLCosmos(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
_backends = ["torch"]
@@ -430,6 +445,21 @@ class ControlNetXSAdapter(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class CosmosTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DiTTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -565,6 +595,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -392,6 +392,51 @@ class CogView4Pipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ConsisIDPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CosmosTextToWorldPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CosmosVideoToWorldPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -692,6 +737,21 @@ class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoFramepackPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HunyuanVideoImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1247,6 +1307,21 @@ class LTXImageToVideoPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LTXLatentUpsamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTXPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -2732,6 +2807,36 @@ class VideoToVideoSDPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class VisualClozeGenerationPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VisualClozePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+39
View File
@@ -215,6 +215,10 @@ _gguf_available, _gguf_version = _is_package_available("gguf")
_torchao_available, _torchao_version = _is_package_available("torchao")
_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
def is_torch_available():
@@ -353,6 +357,22 @@ def is_timm_available():
return _timm_available
def is_pytorch_retinaface_available():
return _pytorch_retinaface_available
def is_better_profanity_available():
return _better_profanity_available
def is_nltk_available():
return _nltk_available
def is_cosmos_guardrail_available():
return _cosmos_guardrail_available
def is_hpu_available():
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
@@ -505,6 +525,22 @@ QUANTO_IMPORT_ERROR = """
install optimum-quanto`
"""
# docstyle-ignore
PYTORCH_RETINAFACE_IMPORT_ERROR = """
{0} requires the pytorch_retinaface library but it was not found in your environment. You can install it with pip: `pip install pytorch_retinaface`
"""
# docstyle-ignore
BETTER_PROFANITY_IMPORT_ERROR = """
{0} requires the better_profanity library but it was not found in your environment. You can install it with pip: `pip install better_profanity`
"""
# docstyle-ignore
NLTK_IMPORT_ERROR = """
{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
"""
BACKENDS_MAPPING = OrderedDict(
[
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -533,6 +569,9 @@ BACKENDS_MAPPING = OrderedDict(
("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
]
)
+8
View File
@@ -38,6 +38,7 @@ from .import_utils import (
is_note_seq_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
is_peft_available,
is_timm_available,
is_torch_available,
@@ -486,6 +487,13 @@ def require_bitsandbytes(test_case):
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
def require_quanto(test_case):
"""
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
"""
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
def require_accelerate(test_case):
"""
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+15 -14
View File
@@ -31,13 +31,14 @@ from diffusers import FlowMatchEulerDiscreteScheduler, FluxControlPipeline, Flux
from diffusers.utils import load_image, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
floats_tensor,
is_peft_available,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_big_accelerator,
require_peft_backend,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -809,10 +810,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@slow
@nightly
@require_torch_gpu
@require_torch_accelerator
@require_peft_backend
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -827,7 +828,7 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
@@ -836,13 +837,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
del self.pipeline
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
# Instead of calling `enable_model_cpu_offload()`, we do a cuda placement here because the CI
# Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
self.pipeline = self.pipeline.to(torch_device)
@@ -956,10 +957,10 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
@nightly
@require_torch_gpu
@require_torch_accelerator
@require_peft_backend
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
@@ -969,17 +970,17 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
self.pipeline = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
).to(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora(self, lora_ckpt_id):
+17 -9
View File
@@ -28,13 +28,16 @@ from diffusers import (
HunyuanVideoTransformer3DModel,
)
from diffusers.utils.testing_utils import (
Expectations,
backend_empty_cache,
floats_tensor,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_big_accelerator,
require_peft_backend,
require_torch_gpu,
require_torch_accelerator,
skip_mps,
torch_device,
)
@@ -192,10 +195,10 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@nightly
@require_torch_gpu
@require_torch_accelerator
@require_peft_backend
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
@require_big_accelerator
@pytest.mark.big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.
@@ -210,7 +213,7 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
@@ -218,13 +221,13 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
)
self.pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.float16
).to("cuda")
).to(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
@@ -249,8 +252,13 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
out_slice = np.concatenate((out[:8], out[-8:]))
# fmt: off
expected_slice = np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815])
expected_slices = Expectations(
{
("cuda", 7): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
}
)
# fmt: on
expected_slice = expected_slices.get_expectation()
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
+2 -2
View File
@@ -93,12 +93,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
# Keeping this test here makes sense because it doesn't look any integration
# (value assertions on logits).
+3 -3
View File
@@ -34,7 +34,7 @@ from diffusers.utils.testing_utils import (
is_flaky,
nightly,
numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda,
require_big_accelerator,
require_peft_backend,
require_torch_accelerator,
torch_device,
@@ -138,8 +138,8 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@nightly
@require_torch_accelerator
@require_peft_backend
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
@require_big_accelerator
@pytest.mark.big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
+7 -6
View File
@@ -37,12 +37,13 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
is_flaky,
load_image,
nightly,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
require_torch_accelerator,
slow,
torch_device,
)
@@ -105,12 +106,12 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
@is_flaky
def test_multiple_wrong_adapter_name_raises_error(self):
@@ -119,18 +120,18 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
@slow
@nightly
@require_torch_gpu
@require_torch_accelerator
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
backend_empty_cache(torch_device)
def test_sdxl_1_0_lora(self):
generator = torch.Generator("cpu").manual_seed(0)
@@ -0,0 +1,86 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLCosmos
from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_cosmos_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 4,
"encoder_block_out_channels": (8, 8, 8, 8),
"decode_block_out_channels": (8, 8, 8, 8),
"attention_resolutions": (8,),
"resolution": 64,
"num_layers": 2,
"patch_size": 4,
"patch_type": "haar",
"scaling_factor": 1.0,
"spatial_compression_ratio": 4,
"temporal_compression_ratio": 4,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
height = 32
width = 32
image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 9, 32, 32)
@property
def output_shape(self):
return (3, 9, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_cosmos_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CosmosEncoder3d",
"CosmosDecoder3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Not sure why this test fails. Investigate later.")
def test_effective_gradient_checkpointing(self):
pass
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass
+108 -70
View File
@@ -62,7 +62,6 @@ from diffusers.utils.testing_utils import (
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
floats_tensor,
get_python_version,
is_torch_compile,
numpy_cosine_similarity_distance,
@@ -1581,6 +1580,34 @@ class ModelTesterMixin:
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
@require_torch_accelerator
@torch.no_grad()
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
if not getattr(model, "_supports_group_offloading", True):
return
model.to(torch_device)
model.eval()
_ = model(**inputs_dict)[0]
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
storage_dtype, compute_dtype = torch.float16, torch.float32
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
model.enable_group_offload(
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
@@ -1754,7 +1781,7 @@ class TorchCompileTesterMixin:
@require_peft_backend
@require_peft_version_greater("0.14.0")
@is_torch_compile
class TestLoraHotSwappingForModel(unittest.TestCase):
class LoraHotSwappingForModelTesterMixin:
"""Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT and is extensively
@@ -1775,48 +1802,24 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
gc.collect()
backend_empty_cache(torch_device)
def get_small_unet(self):
# from diffusers UNet2DConditionModelTests
torch.manual_seed(0)
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 8,
"attention_head_dim": 2,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 1,
"sample_size": 16,
}
model = UNet2DConditionModel(**init_dict)
return model.to(torch_device)
def get_unet_lora_config(self, lora_rank, lora_alpha, target_modules):
def get_lora_config(self, lora_rank, lora_alpha, target_modules):
# from diffusers test_models_unet_2d_condition.py
from peft import LoraConfig
unet_lora_config = LoraConfig(
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config
return lora_config
def get_dummy_input(self):
# from UNet2DConditionModelTests
batch_size = 4
num_channels = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
def get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]
def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None):
"""
@@ -1834,23 +1837,27 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
fine.
"""
# create 2 adapters with different ranks and alphas
dummy_input = self.get_dummy_input()
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self.get_unet_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self.get_unet_lora_config(rank1, alpha1, target_modules1)
lora_config0 = self.get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self.get_lora_config(rank1, alpha1, target_modules1)
unet = self.get_small_unet()
unet.add_adapter(lora_config0, adapter_name="adapter0")
model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
output0_before = unet(**dummy_input)["sample"]
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]
unet.add_adapter(lora_config1, adapter_name="adapter1")
unet.set_adapter("adapter1")
model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
output1_before = unet(**dummy_input)["sample"]
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]
# sanity checks:
tol = 5e-3
@@ -1860,40 +1867,43 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmp_dirname:
# save the adapter checkpoints
unet.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
unet.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
del unet
model.save_lora_adapter(os.path.join(tmp_dirname, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_dirname, "1"), safe_serialization=True, adapter_name="adapter1")
del model
# load the first adapter
unet = self.get_small_unet()
torch.manual_seed(0)
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
unet.enable_lora_hotswap(target_rank=max_rank)
model.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(os.path.join(tmp_dirname, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_dirname, "1"), "pytorch_lora_weights.safetensors")
unet.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
unet = torch.compile(unet, mode="reduce-overhead")
model = torch.compile(model, mode="reduce-overhead")
with torch.inference_mode():
output0_after = unet(**dummy_input)["sample"]
output0_after = model(**inputs_dict)["sample"]
assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
unet.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
output1_after = unet(**dummy_input)["sample"]
output1_after = model(**inputs_dict)["sample"]
assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with self.assertRaisesRegex(ValueError, msg):
unet.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_model(self, rank0, rank1):
@@ -1910,6 +1920,9 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
return
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True):
@@ -1917,52 +1930,77 @@ class TestLoraHotSwappingForModel(unittest.TestCase):
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
return
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
# block.
target_modules = ["to_q"]
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
target_modules.append(self.get_linear_module_name_other_than_attn(model))
del model
# It's important to add this context to raise an error on recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
self.check_model_hotswap(do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with self.assertRaisesRegex(RuntimeError, msg):
unet.enable_lora_hotswap(target_rank=32)
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
from diffusers.loaders.peft import logger
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with self.assertLogs(logger=logger, level="WARNING") as cm:
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in log for log in cm.output)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Capture all warnings
unet.enable_lora_hotswap(target_rank=32, check_compiled="warn")
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self.get_unet_lora_config(8, 8, target_modules=["to_q"])
unet = self.get_small_unet()
unet.add_adapter(lora_config)
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with self.assertRaisesRegex(ValueError, msg):
unet.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self):
# check the error and log

Some files were not shown because too many files have changed in this diff Show More