Compare commits

...

73 Commits

Author SHA1 Message Date
DN6 7e13c18b29 update 2025-04-08 20:38:58 +05:30
Linoy Tsaban 71f34fc5a4 [Flux LoRA] fix issues in flux lora scripts (#11111)
* remove custom scheduler

* update requirements.txt

* log_validation with mixed precision

* add intermediate embeddings saving when checkpointing is enabled

* remove comment

* fix validation

* add unwrap_model for accelerator, torch.no_grad context for validation, fix accelerator.accumulate call in advanced script

* revert unwrap_model change temp

* add .module to address distributed training bug + replace accelerator.unwrap_model with unwrap model

* changes to align advanced script with canonical script

* make changes for distributed training + unify unwrap_model calls in advanced script

* add module.dtype fix to dreambooth script

* unify unwrap_model calls in dreambooth script

* fix condition in validation run

* mixed precision

* Update examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* smol style change

* change autocast

* Apply style fixes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-08 17:40:30 +03:00
Yao Matrix c51b6bd837 introduce compute arch specific expectations and fix test_sd3_img2img_inference failure (#11227)
* add arch specfic expectations support, to support different arch's numerical characteristics

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix typo

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* Apply suggestions from code review

* Apply style fixes

* Update src/diffusers/utils/testing_utils.py

---------

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-08 14:57:49 +01:00
Benjamin Bossan fb54499614 [LoRA] Implement hot-swapping of LoRA (#9453)
* [WIP][LoRA] Implement hot-swapping of LoRA

This PR adds the possibility to hot-swap LoRA adapters. It is WIP.

Description

As of now, users can already load multiple LoRA adapters. They can
offload existing adapters or they can unload them (i.e. delete them).
However, they cannot "hotswap" adapters yet, i.e. substitute the weights
from one LoRA adapter with the weights of another, without the need to
create a separate LoRA adapter.

Generally, hot-swapping may not appear not super useful but when the
model is compiled, it is necessary to prevent recompilation. See #9279
for more context.

Caveats

To hot-swap a LoRA adapter for another, these two adapters should target
exactly the same layers and the "hyper-parameters" of the two adapters
should be identical. For instance, the LoRA alpha has to be the same:
Given that we keep the alpha from the first adapter, the LoRA scaling
would be incorrect for the second adapter otherwise.

Theoretically, we could override the scaling dict with the alpha values
derived from the second adapter's config, but changing the dict will
trigger a guard for recompilation, defeating the main purpose of the
feature.

I also found that compilation flags can have an impact on whether this
works or not. E.g. when passing "reduce-overhead", there will be errors
of the type:

> input name: arg861_1. data pointer changed from 139647332027392 to
139647331054592

I don't know enough about compilation to determine whether this is
problematic or not.

Current state

This is obviously WIP right now to collect feedback and discuss which
direction to take this. If this PR turns out to be useful, the
hot-swapping functions will be added to PEFT itself and can be imported
here (or there is a separate copy in diffusers to avoid the need for a
min PEFT version to use this feature).

Moreover, more tests need to be added to better cover this feature,
although we don't necessarily need tests for the hot-swapping
functionality itself, since those tests will be added to PEFT.

Furthermore, as of now, this is only implemented for the unet. Other
pipeline components have yet to implement this feature.

Finally, it should be properly documented.

I would like to collect feedback on the current state of the PR before
putting more time into finalizing it.

* Reviewer feedback

* Reviewer feedback, adjust test

* Fix, doc

* Make fix

* Fix for possible g++ error

* Add test for recompilation w/o hotswapping

* Make hotswap work

Requires https://github.com/huggingface/peft/pull/2366

More changes to make hotswapping work. Together with the mentioned PEFT
PR, the tests pass for me locally.

List of changes:

- docstring for hotswap
- remove code copied from PEFT, import from PEFT now
- adjustments to PeftAdapterMixin.load_lora_adapter (unfortunately, some
  state dict renaming was necessary, LMK if there is a better solution)
- adjustments to UNet2DConditionLoadersMixin._process_lora: LMK if this
  is even necessary or not, I'm unsure what the overall relationship is
  between this and PeftAdapterMixin.load_lora_adapter
- also in UNet2DConditionLoadersMixin._process_lora, I saw that there is
  no LoRA unloading when loading the adapter fails, so I added it
  there (in line with what happens in PeftAdapterMixin.load_lora_adapter)
- rewritten tests to avoid shelling out, make the test more precise by
  making sure that the outputs align, parametrize it
- also checked the pipeline code mentioned in this comment:
  https://github.com/huggingface/diffusers/pull/9453#issuecomment-2418508871;
  when running this inside the with
  torch._dynamo.config.patch(error_on_recompile=True) context, there is
  no error, so I think hotswapping is now working with pipelines.

* Address reviewer feedback:

- Revert deprecated method
- Fix PEFT doc link to main
- Don't use private function
- Clarify magic numbers
- Add pipeline test

Moreover:
- Extend docstrings
- Extend existing test for outputs != 0
- Extend existing test for wrong adapter name

* Change order of test decorators

parameterized.expand seems to ignore skip decorators if added in last
place (i.e. innermost decorator).

* Split model and pipeline tests

Also increase test coverage by also targeting conv2d layers (support of
which was added recently on the PEFT PR).

* Reviewer feedback: Move decorator to test classes

... instead of having them on each test method.

* Apply suggestions from code review

Co-authored-by: hlky <hlky@hlky.ac>

* Reviewer feedback: version check, TODO comment

* Add enable_lora_hotswap method

* Reviewer feedback: check _lora_loadable_modules

* Revert changes in unet.py

* Add possibility to ignore enabled at wrong time

* Fix docstrings

* Log possible PEFT error, test

* Raise helpful error if hotswap not supported

I.e. for the text encoder

* Formatting

* More linter

* More ruff

* Doc-builder complaint

* Update docstring:

- mention no text encoder support yet
- make it clear that LoRA is meant
- mention that same adapter name should be passed

* Fix error in docstring

* Update more methods with hotswap argument

- SDXL
- SD3
- Flux

No changes were made to load_lora_into_transformer.

* Add hotswap argument to load_lora_into_transformer

For SD3 and Flux. Use shorter docstring for brevity.

* Extend docstrings

* Add version guards to tests

* Formatting

* Fix LoRA loading call to add prefix=None

See:
https://github.com/huggingface/diffusers/pull/10187#issuecomment-2717571064

* Run make fix-copies

* Add hot swap documentation to the docs

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-04-08 17:05:31 +05:30
Álvaro Somoza 723dbdd363 [Training] Better image interpolation in training scripts (#11206)
* initial

* Update examples/dreambooth/train_dreambooth_lora_sdxl.py

Co-authored-by: hlky <hlky@hlky.ac>

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-04-08 12:26:07 +05:30
Bhavay Malhotra fbf61f465b [train_controlnet.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#8461)
* Create diffusers.yml

* fix num_train_epochs

* Delete diffusers.yml

* Fixed Changes

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-08 12:10:09 +05:30
Inigo Goiri 841504bb1a Add support to pass image embeddings to the WAN I2V pipeline. (#11175)
* Add support to pass image embeddings to the pipeline.



---------

Co-authored-by: hlky <hlky@hlky.ac>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-07 15:47:06 -10:00
Steven Liu fc7a867ae5 [docs] MPS update (#11212)
mps
2025-04-07 14:32:27 -10:00
alex choi 5ded26cdc7 ensure dtype match between diffused latents and vae weights (#8391) 2025-04-07 12:59:10 -10:00
Yao Matrix 506f39af3a enable 1 case on XPU (#11219)
enable case on XPU: 1. tests/quantization/bnb/test_mixed_int8.py::BnB8bitTrainingTests::test_training

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-04-07 08:24:21 +01:00
Mikko Tukiainen 8ad68c1393 Add missing MochiEncoder3D.gradient_checkpointing attribute (#11146)
* Add missing 'gradient_checkpointing = False' attr

* Add (limited) tests for Mochi autoencoder

* Apply style fixes

* pass 'conv_cache' as arg instead of kwarg

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-06 02:46:45 +05:30
Edna 41afb6690c Add Wan with STG as a community pipeline (#11184)
* Add stg wan to community pipelines

* remove debug prints

* remove unused comment

* Update doc

* Add credit + fix typo

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-05 04:00:40 +02:00
Tolga Cangöz 13e48492f0 [LTX0.9.5] Refactor LTXConditionPipeline for text-only conditioning (#11174)
* Refactor `LTXConditionPipeline` to add text-only conditioning

* style

* up

* Refactor `LTXConditionPipeline` to streamline condition handling and improve clarity

* Improve condition checks

* Simplify latents handling based on conditioning type

* Refactor rope_interpolation_scale preparation for clarity and efficiency

* Update LTXConditionPipeline docstring to clarify supported input types

* Add LTX Video 0.9.5 model to documentation

* Clarify documentation to indicate support for text-only conditioning without passing `conditions`

* refactor: comment out unused parameters in LTXConditionPipeline

* fix: restore previously commented parameters in LTXConditionPipeline

* fix: remove unused parameters from LTXConditionPipeline

* refactor: remove unnecessary lines in LTXConditionPipeline
2025-04-04 16:43:15 +02:00
Suprhimp 94f2c48d58 [feat]Add strength in flux_fill pipeline (denoising strength for fluxfill) (#10603)
* [feat]add strength in flux_fill pipeline

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* [refactor] refactor after review

* [fix] change comment

* Apply style fixes

* empty

* fix

* update prepare_latents from flux.img2img pipeline

* style

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

---------
2025-04-04 11:23:30 -03:00
Dhruv Nair aabf8ce20b Fix Single File loading for LTX VAE (#11200)
update
2025-04-04 18:02:39 +05:30
Kenneth Gerald Hamilton f10775b1b5 Fixed requests.get function call by adding timeout parameter. (#11156)
* Fixed requests.get function call by adding timeout parameter.

* declare DIFFUSERS_REQUEST_TIMEOUT in constants and import when needed

* remove unneeded os import

* Apply style fixes

---------

Co-authored-by: Sai-Suraj-27 <sai.suraj.27.729@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-04 07:23:14 +01:00
célina 6edb774b5e Update Style Bot workflow (#11202)
update style bot workflow
2025-04-03 19:31:49 +02:00
Basile Lewandowski 480510ada9 Change KolorsPipeline LoRA Loader to StableDiffusion (#11198)
Change LoRA Loader to StableDiffusion

Replace the SDXL LoRA Loader Mixin inheritance with the StableDiffusion one
2025-04-03 11:21:11 -03:00
Abhipsha Das d9023a671a [Model Card] standardize advanced diffusion training sdxl lora (#7615)
* model card gen code

* push modelcard creation

* remove optional from params

* add import

* add use_dora check

* correct lora var use in tags

* make style && make quality

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-03 07:43:01 +05:30
Eliseu Silva c4646a3931 feat: [Community Pipeline] - FaithDiff Stable Diffusion XL Pipeline (#11188)
* feat: [Community Pipeline] - FaithDiff Stable Diffusion XL Pipeline for Image SR.

* added pipeline
2025-04-02 11:33:19 -10:00
Dhruv Nair c97b709afa Add CacheMixin to Wan and LTX Transformers (#11187)
* update

* update

* update
2025-04-02 10:16:31 -10:00
lakshay sharma b0ff822ed3 Update import_utils.py (#10329)
added onnxruntime-vitisai for custom build onnxruntime pkg
2025-04-02 20:47:10 +01:00
hlky 78c2fdc52e SchedulerMixin from_pretrained and ConfigMixin Self type annotation (#11192) 2025-04-02 08:24:02 -10:00
hlky 54dac3a87c Fix enable_sequential_cpu_offload in CogView4Pipeline (#11195)
* Fix enable_sequential_cpu_offload in CogView4Pipeline

* make fix-copies
2025-04-02 16:51:23 +01:00
hlky e5c6027ef8 [docs] torch_dtype map (#11194) 2025-04-02 12:46:28 +01:00
hlky da857bebb6 Revert save_model in ModelMixin save_pretrained and use safe_serialization=False in test (#11196) 2025-04-02 12:45:36 +01:00
Fanli Lin 52b460feb9 [tests] HunyuanDiTControlNetPipeline inference precision issue on XPU (#11197)
* add xpu part

* fix more cases

* remove some cases

* no canny

* format fix
2025-04-02 12:45:02 +01:00
hlky d8c617ccb0 allow models to run with a user-provided dtype map instead of a single dtype (#10301)
* allow models to run with a user-provided dtype map instead of a single dtype

* make style

* Add warning, change `_` to `default`

* make style

* add test

* handle shared tensors

* remove warning

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-04-02 09:05:46 +01:00
Bruno Magalhaes fe2b397426 remove unnecessary call to F.pad (#10620)
* rewrite memory count without implicitly using dimensions by @ic-synth

* replace F.pad by built-in padding in Conv3D

* in-place sums to reduce memory allocations

* fixed trailing whitespace

* file reformatted

* in-place sums

* simpler in-place expressions

* removed in-place sum, may affect backward propagation logic

* removed in-place sum, may affect backward propagation logic

* removed in-place sum, may affect backward propagation logic

* reverted change
2025-04-02 08:19:51 +01:00
Eliseu Silva be0b7f55cc fix: for checking mandatory and optional pipeline components (#11189)
fix: optional componentes verification on load
2025-04-02 08:07:24 +01:00
jiqing-feng 4d5a96e40a fix autocast (#11190)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
2025-04-02 07:26:27 +01:00
Yao Matrix a7f07c1ef5 map BACKEND_RESET_MAX_MEMORY_ALLOCATED to reset_peak_memory_stats on XPU (#11191)
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-04-02 07:25:48 +01:00
Dhruv Nair df1d7b01f1 [WIP] Add Wan Video2Video (#11053)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
2025-04-01 17:22:11 +05:30
Fanli Lin 5a6edac087 [tests] no hard-coded cuda (#11186)
no cuda only
2025-04-01 12:14:31 +01:00
kakukakujirori e8fc8b1f81 Bug fix in LTXImageToVideoPipeline.prepare_latents() when latents is already set (#10918)
* Bug fix in ltx

* Assume packed latents.

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-03-31 12:15:43 -10:00
hlky d6f4774c1c Add latents_mean and latents_std to SDXLLongPromptWeightingPipeline (#11034) 2025-03-31 11:32:29 -10:00
Mark eb50defff2 [Docs] Fix environment variables in installation.md (#11179) 2025-03-31 09:15:25 -07:00
Aryan 2c59af7222 Raise warning and round down if Wan num_frames is not 4k + 1 (#11167)
* update

* raise warning and round to nearest multiple of scale factor
2025-03-31 13:33:28 +05:30
hlky 75d7e5cc45 Fix LatteTransformer3DModel dtype mismatch with enable_temporal_attentions (#11139) 2025-03-29 15:52:56 +01:00
Dhruv Nair 617c208bb4 [Docs] Update Wan Docs with memory optimizations (#11089)
* update

* update
2025-03-28 19:05:56 +05:30
hlky 5d970a4aa9 WanI2V encode_image (#11164)
* WanI2V encode_image
2025-03-28 18:05:34 +05:30
kentdan3msu de6a88c2d7 Set self._hf_peft_config_loaded to True when LoRA is loaded using load_lora_adapter in PeftAdapterMixin class (#11155)
set self._hf_peft_config_loaded to True on successful lora load

Sets the `_hf_peft_config_loaded` flag if a LoRA is successfully loaded in `load_lora_adapter`. Fixes bug huggingface/diffusers/issues/11148

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-26 18:31:18 +01:00
Dhruv Nair 7dc52ea769 [Quantization] dtype fix for GGUF + fix BnB tests (#11159)
* update

* update

* update

* update
2025-03-26 22:22:16 +05:30
Junsong Chen 739d6ec731 add a timestep scale for sana-sprint teacher model (#11150) 2025-03-25 08:47:39 -10:00
Aryan 1ddf3f3a19 Improve information about group offloading and layerwise casting (#11101)
* update

* Update docs/source/en/optimization/memory.md

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* apply review suggestions

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-03-24 23:25:59 +05:30
Jun Yeop Na 7aac77affa [doc] Fix Korean Controlnet Train doc (#11141)
* remove typo from korean controlnet train doc

* removed more paragraphs to remain in sync with the english document
2025-03-24 09:38:21 -07:00
Aryan 8907a70a36 New HunyuanVideo-I2V (#11066)
* update

* update

* update

* add tests

* update docs

* raise value error

* warning for true cfg and guidance scale

* fix test
2025-03-24 21:18:40 +05:30
Junsong Chen 5dbe4f5de6 [fix SANA-Sprint] (#11142)
* fix bug in sana conversion script;

* add more model paths;

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-23 23:38:14 -10:00
Yuxuan Zhang 1d37f42055 Modify the implementation of retrieve_timesteps in CogView4-Control. (#11125)
* 1

* change to channel 1

* cogview4 control training

* add CacheMixin

* 1

* remove initial_input_channels change for val

* 1

* update

* use 3.5

* new loss

* 1

* use imagetoken

* for megatron convert

* 1

* train con and uc

* 2

* remove guidance_scale

* Update pipeline_cogview4_control.py

* fix

* use cogview4 pipeline with timestep

* update shift_factor

* remove the uncond

* add max length

* change convert and use GLMModel instead of GLMForCasualLM

* fix

* [cogview4] Add attention mask support to transformer model

* [fix] Add attention mask for padded token

* update

* remove padding type

* Update train_control_cogview4.py

* resolve conflicts with #10981

* add control convert

* use control format

* fix

* add missing import

* update with cogview4 formate

* make style

* Update pipeline_cogview4_control.py

* Update pipeline_cogview4_control.py

* remove

* Update pipeline_cogview4_control.py

* put back

* Apply style fixes

---------

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-03-23 21:17:14 +05:30
Tolga Cangöz 0213179ba8 Update README and example code for AnyText usage (#11028)
* [Documentation] Update README and example code with additional usage instructions for AnyText

* [Documentation] Update README for AnyTextPipeline and improve logging in code

* Remove wget command for font file from example docstring in anytext.py
2025-03-23 21:15:57 +05:30
hlky a7d53a5939 Don't override torch_dtype and don't use when quantization_config is set (#11039)
* Don't use `torch_dtype` when `quantization_config` is set

* up

* djkajka

* Apply suggestions from code review

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-03-21 21:58:38 +05:30
YiYi Xu 8a63aa5e4f add sana-sprint (#11074)
* add sana-sprint




---------

Co-authored-by: Junsong Chen <cjs1020440147@icloud.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Aryan <aryan@huggingface.co>
2025-03-21 06:21:18 -10:00
Aryan 844221ae4e [core] FasterCache (#10163)
* init

* update

* update

* update

* make style

* update

* fix

* make it work with guidance distilled models

* update

* make fix-copies

* add tests

* update

* apply_faster_cache -> apply_fastercache

* fix

* reorder

* update

* refactor

* update docs

* add fastercache to CacheMixin

* update tests

* Apply suggestions from code review

* make style

* try to fix partial import error

* Apply style fixes

* raise warning

* update

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-03-21 09:35:04 +05:30
CyberVy 9b2c0a7dbe fix _callback_tensor_inputs of sd controlnet inpaint pipeline missing some elements (#11073)
* Update pipeline_controlnet_inpaint.py

* Apply style fixes
2025-03-20 23:56:12 -03:00
Parag Ekbote f424b1b062 Notebooks for Community Scripts-8 (#11128)
Add 4 Notebooks and update the missing links for the
example README.
2025-03-20 12:24:46 -07:00
YiYi Xu e9fda3924f remove F.rms_norm for now (#11126)
up
2025-03-20 07:55:01 -10:00
Dhruv Nair 2c1ed50fc5 Provide option to reduce CPU RAM usage in Group Offload (#11106)
* update

* update

* clean up
2025-03-20 17:01:09 +05:30
Fanli Lin 15ad97f782 [tests] make cuda only tests device-agnostic (#11058)
* enable bnb on xpu

* add 2 more cases

* add missing change

* add missing change

* add one more

* enable cuda only tests on xpu

* enable big gpu cases
2025-03-20 10:12:35 +00:00
hlky 9f2d5c9ee9 Flux with Remote Encode (#11091)
* Flux img2img remote encode

* Flux inpaint

* -copied from
2025-03-20 09:44:08 +00:00
Junsong Chen dc62e6931e [fix bug] PixArt inference_steps=1 (#11079)
* fix bug when pixart-dmd inference with `num_inference_steps=1`

* use return_dict=False and return [1] element for 1-step pixart model, which works for both lcm and dmd
2025-03-20 07:44:30 +00:00
Fanli Lin 56f740051d [tests] enable bnb tests on xpu (#11001)
* enable bnb on xpu

* add 2 more cases

* add missing change

* add missing change

* add one more
2025-03-19 16:33:11 +00:00
Linoy Tsaban a34d97cef0 [Wan LoRAs] make T2V LoRAs compatible with Wan I2V (#11107)
* @hlky t2v->i2v

* Apply style fixes

* try with ones to not nullify layers

* fix method name

* revert to zeros

* add check to state_dict keys

* add comment

* copies fix

* Revert "copies fix"

This reverts commit 051f534d18.

* remove copied from

* Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: hlky <hlky@hlky.ac>

* update

* update

* Update src/diffusers/loaders/lora_pipeline.py

Co-authored-by: hlky <hlky@hlky.ac>

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Linoy <linoy@hf.co>
Co-authored-by: hlky <hlky@hlky.ac>
2025-03-19 21:44:19 +05:30
Yuqian Hong fc28791fc8 [BUG] Fix Autoencoderkl train script (#11113)
* add disc_optimizer step (not fix)

* support syncbatchnorm in discriminator
2025-03-19 16:49:02 +05:30
Sayak Paul ae14612673 [CI] uninstall deps properly from pr gpu tests. (#11102)
uninstall deps properly from pr gpu tests.
2025-03-19 08:58:36 +05:30
hlky 0ab8fe49bf Quality options in export_to_video (#11090)
* Quality options in `export_to_video`

* make style
2025-03-18 10:32:33 -10:00
Aryan 3be6706018 Fix Group offloading behaviour when using streams (#11097)
* update

* update
2025-03-18 14:44:10 +05:30
Cheng Jin cb1b8b21b8 Resolve stride mismatch in UNet's ResNet to support Torch DDP (#11098)
Modify UNet's ResNet implementation to resolve stride mismatch in Torch's DDP
2025-03-18 07:38:13 +00:00
Juan Acevedo 27916822b2 update readme instructions. (#11096)
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
2025-03-17 20:07:48 -10:00
co63oc 3fe3bc0642 Fix pipeline_flux_controlnet.py (#11095)
* Fix pipeline_flux_controlnet.py

* Fix style
2025-03-17 19:52:15 -10:00
Aryan 813d42cc96 Group offloading improvements (#11094)
update
2025-03-18 11:18:00 +05:30
Sayak Paul b4d7e9c632 make PR GPU tests conditioned on styling. (#11099) 2025-03-18 11:15:35 +05:30
Aryan 2e83cbbb6d LTX 0.9.5 (#10968)
* update


---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-03-17 16:43:36 -10:00
C 33d10af28f Fix Wan I2V Quality (#11087)
* fix_wan_i2v_quality

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update pipeline_wan_i2v.py

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: hlky <hlky@hlky.ac>
2025-03-17 06:24:57 -10:00
151 changed files with 12626 additions and 988 deletions
-34
View File
@@ -13,39 +13,5 @@ jobs:
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
with:
python_quality_dependencies: "[quality]"
pre_commit_script_name: "Download and Compare files from the main branch"
pre_commit_script: |
echo "Downloading the files from the main branch"
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
echo "Compare the files and raise error if needed"
diff_failed=0
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if ! diff -q main_setup.py setup.py; then
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if [ $diff_failed -eq 1 ]; then
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
exit 1
fi
echo "No changes in the files. Proceeding..."
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
style_command: "make style && make quality"
secrets:
bot_token: ${{ secrets.GITHUB_TOKEN }}
+47 -1
View File
@@ -28,7 +28,51 @@ env:
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check quality
run: make quality
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
check_repository_consistency:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install .[quality]
- name: Check repo consistency
run: |
python utils/check_copies.py
python utils/check_dummies.py
python utils/check_support_list.py
make deps_table_check_updated
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
runs-on:
group: aws-general-8-plus
@@ -133,6 +177,7 @@ jobs:
torch_cuda_tests:
name: Torch CUDA Tests
needs: [check_code_quality, check_repository_consistency]
runs-on:
group: aws-g4dn-2xlarge
container:
@@ -201,7 +246,7 @@ jobs:
run_examples_tests:
name: Examples PyTorch CUDA tests on Ubuntu
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
needs: [check_code_quality, check_repository_consistency]
runs-on:
group: aws-g4dn-2xlarge
@@ -220,6 +265,7 @@ jobs:
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
pip uninstall transformers -y && python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
python -m uv pip install -e [quality,test,training]
- name: Environment
+2
View File
@@ -496,6 +496,8 @@
title: PixArt-Σ
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/sana_sprint
title: Sana Sprint
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
+33
View File
@@ -38,6 +38,33 @@ config = PyramidAttentionBroadcastConfig(
pipe.transformer.enable_cache(config)
```
## Faster Cache
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
FasterCache is a method that speeds up inference in diffusion transformers by:
- Reusing attention states between successive inference steps, due to high similarity between them
- Skipping unconditional branch prediction used in classifier-free guidance by revealing redundancies between unconditional and conditional branch outputs for the same timestep, and therefore approximating the unconditional branch output using the conditional branch output
```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 681),
current_timestep_callback=lambda: pipe.current_timestep,
attention_weight_callback=lambda _: 0.3,
unconditional_batch_skip_range=5,
unconditional_batch_timestep_skip_range=(-1, 781),
tensor_format="BFCHW",
)
pipe.transformer.enable_cache(config)
```
### CacheMixin
[[autodoc]] CacheMixin
@@ -47,3 +74,9 @@ pipe.transformer.enable_cache(config)
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
### FasterCacheConfig
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
## Overview
+1
View File
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Flux is a series of text-to-image generation models based on diffusion transformers. To know more about Flux, check out the original [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/) by the creators of Flux, Black Forest Labs.
@@ -50,7 +50,8 @@ The following models are available for the image-to-video pipeline:
| Model name | Description |
|:---|:---|
| [`Skywork/SkyReels-V1-Hunyuan-I2V`](https://huggingface.co/Skywork/SkyReels-V1-Hunyuan-I2V) | Skywork's custom finetune of HunyuanVideo (de-distilled). Performs best with `97x544x960` resolution. Performs best at `97x544x960` resolution, `guidance_scale=1.0`, `true_cfg_scale=6.0` and a negative prompt. |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
| [`hunyuanvideo-community/HunyuanVideo-I2V-33ch`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 33-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20). |
| [`hunyuanvideo-community/HunyuanVideo-I2V`](https://huggingface.co/hunyuanvideo-community/HunyuanVideo-I2V) | Tecent's official HunyuanVideo 16-channel I2V model. Performs best at resolutions of 480, 720, 960, 1280. A higher `shift` value when initializing the scheduler is recommended (good values are between 7 and 20) |
## Quantization
+1
View File
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
![](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/kolors_header_collage.png)
@@ -16,6 +16,7 @@
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
@@ -32,6 +33,7 @@ Available models:
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
@@ -196,6 +198,12 @@ export_to_video(video, "ship.mp4", fps=24)
- all
- __call__
## LTXConditionPipeline
[[autodoc]] LTXConditionPipeline
- all
- __call__
## LTXPipelineOutput
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
+1
View File
@@ -16,6 +16,7 @@
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
+100
View File
@@ -0,0 +1,100 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaSprintPipeline
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
[SANA-Sprint: One-Step Diffusion with Continuous-Time Consistency Distillation](https://huggingface.co/papers/2503.09641) from NVIDIA, MIT HAN Lab, and Hugging Face by Junsong Chen, Shuchen Xue, Yuyang Zhao, Jincheng Yu, Sayak Paul, Junyu Chen, Han Cai, Enze Xie, Song Han
The abstract from the paper is:
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
Available models:
| Model | Recommended dtype |
|:-------------------------------------------------------------------------------------------------------------------------------------------:|:-----------------:|
| [`Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers) | `torch.bfloat16` |
| [`Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers) | `torch.bfloat16` |
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-sprint-67d6810d65235085b3b17c76) collection for more information.
Note: The recommended dtype mentioned is for the transformer weights. The text encoder must stay in `torch.bfloat16` and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaSprintPipeline`] for inference with bitsandbytes.
```py
import torch
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaTransformer2DModel, SanaSprintPipeline
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
quant_config = BitsAndBytesConfig(load_in_8bit=True)
text_encoder_8bit = AutoModel.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
subfolder="text_encoder",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SanaTransformer2DModel.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
pipeline = SanaSprintPipeline.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
text_encoder=text_encoder_8bit,
transformer=transformer_8bit,
torch_dtype=torch.bfloat16,
device_map="balanced",
)
prompt = "a tiny astronaut hatching from an egg on the moon"
image = pipeline(prompt).images[0]
image.save("sana.png")
```
## Setting `max_timesteps`
Users can tweak the `max_timesteps` value for experimenting with the visual quality of the generated outputs. The default `max_timesteps` value was obtained with an inference-time search process. For more details about it, check out the paper.
## SanaSprintPipeline
[[autodoc]] SanaSprintPipeline
- all
- __call__
## SanaPipelineOutput
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Stable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206.pdf) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.
@@ -14,6 +14,7 @@ specific language governing permissions and limitations under the License.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
</div>
Stable Diffusion XL (SDXL) was proposed in [SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis](https://huggingface.co/papers/2307.01952) by Dustin Podell, Zion English, Kyle Lacey, Andreas Blattmann, Tim Dockhorn, Jonas Müller, Joe Penna, and Robin Rombach.
+395 -12
View File
@@ -22,18 +22,397 @@
<!-- TODO(aryan): update abstract once paper is out -->
## Generating Videos with Wan 2.1
We will first need to install some addtional dependencies.
```shell
pip install -u ftfy imageio-ffmpeg imageio
```
### Text to Video Generation
The following example requires 11GB VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out
for the larger `Wan2.1-I2V-14B-720P-Diffusers` or `Wan-AI/Wan2.1-I2V-14B-480P-Diffusers` if you have at least 35GB VRAM available.
```python
from diffusers import WanPipeline
from diffusers.utils import export_to_video
# Available models: Wan-AI/Wan2.1-I2V-14B-720P-Diffusers or Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
You can improve the quality of the generated video by running the decoding step in full precision.
</Tip>
Recommendations for inference:
- VAE in `torch.float32` for better decoding quality.
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `81`.
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
```python
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
### Using a custom scheduler
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
frames = pipe(prompt=prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```
### Image to Video Generation
The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least
35GB of VRAM to run.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)
# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### Video to Video Generation
```python
import torch
from diffusers.utils import load_video, export_to_video
from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline, UniPCMultistepScheduler
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(
model_id, subfolder="vae", torch_dtype=torch.float32
)
pipe = WanVideoToVideoPipeline.from_pretrained(
model_id, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(
pipe.scheduler.config, flow_shift=flow_shift
)
# change to pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()
prompt = "A robot standing on a mountain top. The sun is setting in the background"
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
video = load_video(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
)
output = pipe(
video=video,
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=512,
guidance_scale=7.0,
strength=0.7,
).frames[0]
export_to_video(output, "wan-v2v.mp4", fps=16)
```
## Memory Optimizations for Wan 2.1
Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.
We'll use `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.
### Group Offloading the Transformer and UMT5 Text Encoder
Find more information about group offloading [here](../optimization/memory.md)
#### Block Level Group Offloading
We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline; the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading will break up the individual modules of a model and offload/onload them onto your GPU as needed during inference. In this example, we'll apply `block_level` offloading, which will group the modules in a model into blocks of size `num_blocks_per_group` and offload/onload them to GPU. Moving to between CPU and GPU does add latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing the `num_blocks_per_group`.
The following example will now only require 14GB of VRAM to run, but will take approximately 30 minutes to generate a video.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(text_encoder,
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4
)
transformer.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4,
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
#### Block Level Group Offloading with CUDA Streams
We can speed up group offloading inference, by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by Pytorch under the hood, and can result in a significant spike in CPU RAM usage. Please consider this option if your CPU RAM is atleast 2X the size of the model you are group offloading.
In the following example we will use CUDA streams when group offloading the `WanTransformer3DModel`. When testing on an A100, this example will require 14GB of VRAM, 52GB of CPU RAM, but will generate a video in approximately 9 minutes.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
onload_device = torch.device("cuda")
offload_device = torch.device("cpu")
apply_group_offloading(text_encoder,
onload_device=onload_device,
offload_device=offload_device,
offload_type="block_level",
num_blocks_per_group=4
)
transformer.enable_group_offload(
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True
)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
# Since we've offloaded the larger models alrady, we can move the rest of the model components to GPU
pipe.to("cuda")
image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
### Applying Layerwise Casting to the Transformer
Find more information about layerwise casting [here](../optimization/memory.md)
In this example, we will model offloading with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast to `torch.bfloat16` during the forward pass of the layer, then revert to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.
This example will require 20GB of VRAM.
```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe = WanImageToVideoPipeline.from_pretrained(
model_id,
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")
max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))
prompt = (
"An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
"the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33
output = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
num_frames=num_frames,
num_inference_steps=50,
guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```
## Using a Custom Scheduler
Wan can be used with many different schedulers, each with their own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:
@@ -49,11 +428,10 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
```
### Using single file loading with Wan
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
method.
## Using Single File Loading with Wan 2.1
The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading
method.
```python
import torch
@@ -65,6 +443,11 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
## Recommendations for Inference
- Keep `AutencoderKLWan` in `torch.float32` for better decoding quality.
- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan.
## WanPipeline
[[autodoc]] WanPipeline
+7 -5
View File
@@ -161,10 +161,10 @@ Your Python environment will find the `main` version of 🤗 Diffusers on the ne
Model weights and files are downloaded from the Hub to a cache which is usually your home directory. You can change the cache location by specifying the `HF_HOME` or `HUGGINFACE_HUB_CACHE` environment variables or configuring the `cache_dir` parameter in methods like [`~DiffusionPipeline.from_pretrained`].
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `True` and 🤗 Diffusers will only load previously downloaded files in the cache.
Cached files allow you to run 🤗 Diffusers offline. To prevent 🤗 Diffusers from connecting to the internet, set the `HF_HUB_OFFLINE` environment variable to `1` and 🤗 Diffusers will only load previously downloaded files in the cache.
```shell
export HF_HUB_OFFLINE=True
export HF_HUB_OFFLINE=1
```
For more details about managing and cleaning the cache, take a look at the [caching](https://huggingface.co/docs/huggingface_hub/guides/manage-cache) guide.
@@ -179,14 +179,16 @@ Telemetry is only sent when loading models and pipelines from the Hub,
and it is not collected if you're loading local files.
We understand that not everyone wants to share additional information,and we respect your privacy.
You can disable telemetry collection by setting the `DISABLE_TELEMETRY` environment variable from your terminal:
You can disable telemetry collection by setting the `HF_HUB_DISABLE_TELEMETRY` environment variable from your terminal:
On Linux/MacOS:
```bash
export DISABLE_TELEMETRY=YES
export HF_HUB_DISABLE_TELEMETRY=1
```
On Windows:
```bash
set DISABLE_TELEMETRY=YES
set HF_HUB_DISABLE_TELEMETRY=1
```
+20
View File
@@ -198,6 +198,18 @@ export_to_video(video, "output.mp4", fps=8)
Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
<Tip>
- Group offloading may not work with all models out-of-the-box. If the forward implementations of the model contain weight-dependent device-casting of inputs, it may clash with the offloading mechanism's handling of device-casting.
- The `offload_type` parameter can be set to either `block_level` or `leaf_level`. `block_level` offloads groups of `torch::nn::ModuleList` or `torch::nn:Sequential` modules based on a configurable attribute `num_blocks_per_group`. For example, if you set `num_blocks_per_group=2` on a standard transformer model containing 40 layers, it will onload/offload 2 layers at a time for a total of 20 onload/offloads. This drastically reduces the VRAM requirements. `leaf_level` offloads individual layers at the lowest level, which is equivalent to sequential offloading. However, unlike sequential offloading, group offloading can be made much faster when using streams, with minimal compromise to end-to-end generation time.
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that available CPU RAM that is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html)
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned-tensors on-the-fly instead of pre-pinning them. This parameter is better suited for using `leaf_level` offloading.
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
</Tip>
## FP8 layerwise weight-casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
@@ -235,6 +247,14 @@ In the above example, layerwise casting is enabled on the transformer component
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
<Tip>
- Layerwise casting may not work with all models out-of-the-box. Sometimes, the forward implementations of the model might contain internal typecasting of weight values. Such implementations are not supported due to the currently simplistic implementation of layerwise casting, which assumes that the forward pass is independent of the weight precision and that the input dtypes are always in `compute_dtype`. An example of an incompatible implementation can be found [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299).
- Layerwise casting may fail on custom modeling implementations that make use of [PEFT](https://github.com/huggingface/peft) layers. Some minimal checks to handle this case is implemented but is not extensively tested or guaranteed to work in all cases.
- It can be also be applied partially to specific layers of a model. Partially applying layerwise casting can either be done manually by calling the `apply_layerwise_casting` function on specific internal modules, or by specifying the `skip_modules_pattern` and `skip_modules_classes` parameters for a root module. These parameters are particularly useful for layers such as normalization and modulation.
</Tip>
## Channels-last memory format
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worst performance but you should still try and see if it works for your model.
+12 -1
View File
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
# Metal Performance Shaders (MPS)
> [!TIP]
> Pipelines with a <img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22"> badge indicate a model can take advantage of the MPS backend on Apple silicon devices for faster inference. Feel free to open a [Pull Request](https://github.com/huggingface/diffusers/compare) to add this badge to pipelines that are missing it.
🤗 Diffusers is compatible with Apple silicon (M1/M2 chips) using the PyTorch [`mps`](https://pytorch.org/docs/stable/notes/mps.html) device, which uses the Metal framework to leverage the GPU on MacOS devices. You'll need to have:
- macOS computer with Apple silicon (M1/M2) hardware
@@ -37,7 +40,7 @@ image
<Tip warning={true}>
Generating multiple prompts in a batch can [crash](https://github.com/huggingface/diffusers/issues/363) or fail to work reliably. We believe this is related to the [`mps`](https://github.com/pytorch/pytorch/issues/84039) backend in PyTorch. While this is being investigated, you should iterate instead of batching.
The PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) backend does not support NDArray sizes greater than `2**32`. Please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose) if you encounter this problem so we can investigate.
</Tip>
@@ -59,6 +62,10 @@ If you're using **PyTorch 1.13**, you need to "prime" the pipeline with an addit
## Troubleshoot
This section lists some common issues with using the `mps` backend and how to solve them.
### Attention slicing
M1/M2 performance is very sensitive to memory pressure. When this occurs, the system automatically swaps if it needs to which significantly degrades performance.
To prevent this from happening, we recommend *attention slicing* to reduce memory pressure during inference and prevent swapping. This is especially relevant if your computer has less than 64GB of system RAM, or if you generate images at non-standard resolutions larger than 512×512 pixels. Call the [`~DiffusionPipeline.enable_attention_slicing`] function on your pipeline:
@@ -72,3 +79,7 @@ pipeline.enable_attention_slicing()
```
Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually improves performance by ~20% in computers without universal memory, but we've observed *better performance* in most Apple silicon computers unless you have 64GB of RAM or more.
### Batch inference
Generating multiple prompts in a batch can crash or fail to work reliably. If this is the case, try iterating instead of batching.
+17
View File
@@ -95,6 +95,23 @@ Use the Space below to gauge a pipeline's memory requirements before you downloa
></iframe>
</div>
### Specifying Component-Specific Data Types
You can customize the data types for individual sub-models by passing a dictionary to the `torch_dtype` parameter. This allows you to load different components of a pipeline in different floating point precisions. For instance, if you want to load the transformer with `torch.bfloat16` and all other components with `torch.float16`, you can pass a dictionary mapping:
```python
from diffusers import HunyuanVideoPipeline
import torch
pipe = HunyuanVideoPipeline.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
torch_dtype={'transformer': torch.bfloat16, 'default': torch.float16},
)
print(pipe.transformer.dtype, pipe.vae.dtype) # (torch.bfloat16, torch.float16)
```
If a component is not explicitly specified in the dictionary and no `default` is provided, it will be loaded with `torch.float32`.
### Local pipeline
To load a pipeline locally, use [git-lfs](https://git-lfs.github.com/) to manually download a checkpoint to your local disk.
@@ -194,6 +194,59 @@ Currently, [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`] only support
</Tip>
### Hotswapping LoRA adapters
A common use case when serving multiple adapters is to load one adapter first, generate images, load another adapter, generate more images, load another adapter, etc. This workflow normally requires calling [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`~loaders.StableDiffusionLoraLoaderMixin.set_adapters`], and possibly [`~loaders.peft.PeftAdapterMixin.delete_adapters`] to save memory. Moreover, if the model is compiled using `torch.compile`, performing these steps requires recompilation, which takes time.
To better support this common workflow, you can "hotswap" a LoRA adapter, to avoid accumulating memory and in some cases, recompilation. It requires an adapter to already be loaded, and the new adapter weights are swapped in-place for the existing adapter.
Pass `hotswap=True` when loading a LoRA adapter to enable this feature. It is important to indicate the name of the existing adapter, (`default_0` is the default adapter name), to be swapped. If you loaded the first adapter with a different name, use that name instead.
```python
pipe = ...
# load adapter 1 as normal
pipeline.load_lora_weights(file_name_adapter_1)
# generate some images with adapter 1
...
# now hot swap the 2nd adapter
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```
<Tip warning={true}>
Hotswapping is not currently supported for LoRA adapters that target the text encoder.
</Tip>
For compiled models, it is often (though not always if the second adapter targets identical LoRA ranks and scales) necessary to call [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] to avoid recompilation. Use [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] _before_ loading the first adapter, and `torch.compile` should be called _after_ loading the first adapter.
```python
pipe = ...
# call this extra method
pipe.enable_lora_hotswap(target_rank=max_rank)
# now load adapter 1
pipe.load_lora_weights(file_name_adapter_1)
# now compile the unet of the pipeline
pipe.unet = torch.compile(pipeline.unet, ...)
# generate some images with adapter 1
...
# now hot swap adapter 2
pipeline.load_lora_weights(file_name_adapter_2, hotswap=True, adapter_name="default_0")
# generate images with adapter 2
```
The `target_rank=max_rank` argument is important for setting the maximum rank among all LoRA adapters that will be loaded. If you have one adapter with rank 8 and another with rank 16, pass `target_rank=16`. You should use a higher value if in doubt. By default, this value is 128.
However, there can be situations where recompilation is unavoidable. For example, if the hotswapped adapter targets more layers than the initial adapter, then recompilation is triggered. Try to load the adapter that targets the most layers first. Refer to the PEFT docs on [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) for more details about the limitations of this feature.
<Tip>
Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If you detect recompilation despite following all the steps above, please open an issue with [Diffusers](https://github.com/huggingface/diffusers/issues) with a reproducible example.
</Tip>
### Kohya and TheLastBen
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
-6
View File
@@ -66,12 +66,6 @@ from accelerate.utils import write_basic_config
write_basic_config()
```
## 원을 채우는 데이터셋
원본 데이터셋은 ControlNet [repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip)에 올라와있지만, 우리는 [여기](https://huggingface.co/datasets/fusing/fill50k)에 새롭게 다시 올려서 🤗 Datasets 과 호환가능합니다. 그래서 학습 스크립트 상에서 데이터 불러오기를 다룰 수 있습니다.
우리의 학습 예시는 원래 ControlNet의 학습에 쓰였던 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)을 사용합니다. 그렇지만 ControlNet은 대응되는 어느 Stable Diffusion 모델([`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)) 혹은 [`stabilityai/stable-diffusion-2-1`](https://huggingface.co/stabilityai/stable-diffusion-2-1)의 증가를 위해 학습될 수 있습니다.
자체 데이터셋을 사용하기 위해서는 [학습을 위한 데이터셋 생성하기](create_dataset) 가이드를 확인하세요.
## 학습
@@ -1,7 +1,8 @@
accelerate>=0.16.0
accelerate>=0.31.0
torchvision
transformers>=4.25.1
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft==0.7.0
peft>=0.11.1
sentencepiece
@@ -24,7 +24,7 @@ import re
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import List, Optional, Union
from typing import List, Optional
import numpy as np
import torch
@@ -228,10 +228,20 @@ def log_validation(
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = nullcontext()
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
).images[0]
images.append(image)
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -657,6 +667,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
)
parser.add_argument(
"--lora_layers",
type=str,
@@ -666,6 +677,7 @@ def parse_args(input_args=None):
'E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/README_flux.md'
),
)
parser.add_argument(
"--adam_epsilon",
type=float,
@@ -738,6 +750,15 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
@@ -1147,7 +1168,7 @@ def tokenize_prompt(tokenizer, prompt, max_sequence_length, add_special_tokens=F
return text_input_ids
def _get_t5_prompt_embeds(
def _encode_prompt_with_t5(
text_encoder,
tokenizer,
max_sequence_length=512,
@@ -1176,7 +1197,10 @@ def _get_t5_prompt_embeds(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -1188,7 +1212,7 @@ def _get_t5_prompt_embeds(
return prompt_embeds
def _get_clip_prompt_embeds(
def _encode_prompt_with_clip(
text_encoder,
tokenizer,
prompt: str,
@@ -1217,9 +1241,13 @@ def _get_clip_prompt_embeds(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -1238,136 +1266,35 @@ def encode_prompt(
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
dtype = text_encoders[0].dtype
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _get_clip_prompt_embeds(
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
tokenizer=tokenizers[0],
prompt=prompt,
device=device if device is not None else text_encoders[0].device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[0] if text_input_ids_list is not None else None,
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
)
prompt_embeds = _get_t5_prompt_embeds(
prompt_embeds = _encode_prompt_with_t5(
text_encoder=text_encoders[1],
tokenizer=tokenizers[1],
max_sequence_length=max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[1].device,
text_input_ids=text_input_ids_list[1] if text_input_ids_list is not None else None,
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
# CustomFlowMatchEulerDiscreteScheduler was taken from ostris ai-toolkit trainer:
# https://github.com/ostris/ai-toolkit/blob/9ee1ef2a0a2a9a02b92d114a95f21312e5906e54/toolkit/samplers/custom_flowmatch_sampler.py#L95
class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
with torch.no_grad():
# create weights for timesteps
num_timesteps = 1000
# generate the multiplier based on cosmap loss weighing
# this is only used on linear timesteps for now
# cosine map weighing is higher in the middle and lower at the ends
# bot = 1 - 2 * self.sigmas + 2 * self.sigmas ** 2
# cosmap_weighing = 2 / (math.pi * bot)
# sigma sqrt weighing is significantly higher at the end and lower at the beginning
sigma_sqrt_weighing = (self.sigmas**-2.0).float()
# clip at 1e4 (1e6 is too high)
sigma_sqrt_weighing = torch.clamp(sigma_sqrt_weighing, max=1e4)
# bring to a mean of 1
sigma_sqrt_weighing = sigma_sqrt_weighing / sigma_sqrt_weighing.mean()
# Create linear timesteps from 1000 to 0
timesteps = torch.linspace(1000, 0, num_timesteps, device="cpu")
self.linear_timesteps = timesteps
# self.linear_timesteps_weights = cosmap_weighing
self.linear_timesteps_weights = sigma_sqrt_weighing
# self.sigmas = self.get_sigmas(timesteps, n_dim=1, dtype=torch.float32, device='cpu')
pass
def get_weights_for_timesteps(self, timesteps: torch.Tensor) -> torch.Tensor:
# Get the indices of the timesteps
step_indices = [(self.timesteps == t).nonzero().item() for t in timesteps]
# Get the weights for the timesteps
weights = self.linear_timesteps_weights[step_indices].flatten()
return weights
def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
sigmas = self.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.timesteps.to(device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def add_noise(
self,
original_samples: torch.Tensor,
noise: torch.Tensor,
timesteps: torch.Tensor,
) -> torch.Tensor:
## ref https://github.com/huggingface/diffusers/blob/fbe29c62984c33c6cf9cf7ad120a992fe6d20854/examples/dreambooth/train_dreambooth_sd3.py#L1578
## Add noise according to flow matching.
## zt = (1 - texp) * x + texp * z1
# sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
# noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
# timestep needs to be in [0, 1], we store them in [0, 1000]
# noisy_sample = (1 - timestep) * latent + timestep * noise
t_01 = (timesteps / 1000).to(original_samples.device)
noisy_model_input = (1 - t_01) * original_samples + t_01 * noise
# n_dim = original_samples.ndim
# sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
# noisy_model_input = (1.0 - sigmas) * original_samples + sigmas * noise
return noisy_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
return sample
def set_train_timesteps(self, num_timesteps, device, linear=False):
if linear:
timesteps = torch.linspace(1000, 0, num_timesteps, device=device)
self.timesteps = timesteps
return timesteps
else:
# distribute them closer to center. Inference distributes them as a bias toward first
# Generate values from 0 to 1
t = torch.sigmoid(torch.randn((num_timesteps,), device=device))
# Scale and reverse the values to go from 1000 to 0
timesteps = (1 - t) * 1000
# Sort the timesteps in descending order
timesteps, _ = torch.sort(timesteps, descending=True)
self.timesteps = timesteps.to(device=device)
return timesteps
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
@@ -1499,7 +1426,7 @@ def main(args):
)
# Load scheduler and models
noise_scheduler = CustomFlowMatchEulerDiscreteScheduler.from_pretrained(
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
noise_scheduler_copy = copy.deepcopy(noise_scheduler)
@@ -1619,7 +1546,6 @@ def main(args):
target_modules=target_modules,
)
transformer.add_adapter(transformer_lora_config)
if args.train_text_encoder:
text_lora_config = LoraConfig(
r=args.rank,
@@ -1727,7 +1653,6 @@ def main(args):
cast_training_params(models, dtype=torch.float32)
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
if args.train_text_encoder:
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
# if we use textual inversion, we freeze all parameters except for the token embeddings
@@ -1737,7 +1662,8 @@ def main(args):
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
if args.mixed_precision == "fp16":
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
@@ -1747,7 +1673,8 @@ def main(args):
for name, param in text_encoder_two.named_parameters():
if "shared" in name:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
param.data = param.to(dtype=torch.float32)
if args.mixed_precision == "fp16":
param.data = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
@@ -1828,6 +1755,7 @@ def main(args):
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
@@ -2021,6 +1949,7 @@ def main(args):
lr_scheduler,
)
else:
print("I SHOULD BE HERE")
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, text_encoder_one, optimizer, train_dataloader, lr_scheduler
)
@@ -2125,7 +2054,7 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
elif args.train_text_encoder_ti: # textual inversion / pivotal tuning
text_encoder_one.train()
if args.enable_t5_ti:
@@ -2137,6 +2066,11 @@ def main(args):
pivoted_tr = True
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
if not freeze_text_encoder:
models_to_accumulate.extend([text_encoder_one])
if args.enable_t5_ti:
models_to_accumulate.extend([text_encoder_two])
if pivoted_te:
# stopping optimization of text_encoder params
optimizer.param_groups[te_idx]["lr"] = 0.0
@@ -2145,7 +2079,7 @@ def main(args):
logger.info(f"PIVOT TRANSFORMER {epoch}")
optimizer.param_groups[0]["lr"] = 0.0
with accelerator.accumulate(transformer):
with accelerator.accumulate(models_to_accumulate):
prompts = batch["prompts"]
# encode batch prompts when custom prompts are provided for each image -
@@ -2189,7 +2123,7 @@ def main(args):
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
@@ -2228,7 +2162,7 @@ def main(args):
)
# handle guidance
if transformer.config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -2288,16 +2222,26 @@ def main(args):
accelerator.backward(loss)
if accelerator.sync_gradients:
if not freeze_text_encoder:
if args.train_text_encoder:
if args.train_text_encoder: # text encoder tuning
params_to_clip = itertools.chain(transformer.parameters(), text_encoder_one.parameters())
elif pure_textual_inversion:
params_to_clip = itertools.chain(
text_encoder_one.parameters(), text_encoder_two.parameters()
)
if args.enable_t5_ti:
params_to_clip = itertools.chain(
text_encoder_one.parameters(), text_encoder_two.parameters()
)
else:
params_to_clip = itertools.chain(text_encoder_one.parameters())
else:
params_to_clip = itertools.chain(
transformer.parameters(), text_encoder_one.parameters(), text_encoder_two.parameters()
)
if args.enable_t5_ti:
params_to_clip = itertools.chain(
transformer.parameters(),
text_encoder_one.parameters(),
text_encoder_two.parameters(),
)
else:
params_to_clip = itertools.chain(
transformer.parameters(), text_encoder_one.parameters()
)
else:
params_to_clip = itertools.chain(transformer.parameters())
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
@@ -2339,6 +2283,10 @@ def main(args):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
if args.train_text_encoder_ti:
embedding_handler.save_embeddings(
f"{args.output_dir}/{Path(args.output_dir).name}_emb_checkpoint_{global_step}.safetensors"
)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -2351,14 +2299,16 @@ def main(args):
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
if freeze_text_encoder:
if freeze_text_encoder: # no text encoder one, two optimizations
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
text_encoder_one.to(weight_dtype)
text_encoder_two.to(weight_dtype)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
text_encoder=unwrap_model(text_encoder_one),
text_encoder_2=unwrap_model(text_encoder_two),
transformer=unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -2372,21 +2322,21 @@ def main(args):
epoch=epoch,
torch_dtype=weight_dtype,
)
images = None
del pipeline
if freeze_text_encoder:
del text_encoder_one, text_encoder_two
free_memory()
elif args.train_text_encoder:
del text_encoder_two
free_memory()
images = None
del pipeline
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
transformer = transformer.to(weight_dtype)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if args.train_text_encoder:
@@ -2428,8 +2378,8 @@ def main(args):
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
is_final_validation=True,
torch_dtype=weight_dtype,
)
save_model_card(
@@ -2452,6 +2402,7 @@ def main(args):
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
images = None
del pipeline
@@ -71,6 +71,7 @@ from diffusers.utils import (
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -101,7 +102,7 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
def save_model_card(
repo_id: str,
use_dora: bool,
images=None,
images: list = None,
base_model: str = None,
train_text_encoder=False,
train_text_encoder_ti=False,
@@ -111,20 +112,17 @@ def save_model_card(
repo_folder=None,
vae_path=None,
):
img_str = "widget:\n"
lora = "lora" if not use_dora else "dora"
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
"""
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""
widget_dict = []
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)
else:
widget_dict.append({"text": instance_prompt})
embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
@@ -169,23 +167,7 @@ pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_en
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""
yaml = f"""---
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- diffusers-training
- text-to-image
- diffusers
- {lora}
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""
model_card = f"""
model_description = f"""
# SDXL LoRA DreamBooth - {repo_id}
<Gallery />
@@ -234,8 +216,25 @@ Special VAE used for training: {vae_path}.
{license}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"stable-diffusion-xl",
"stable-diffusion-xl-diffusers",
"text-to-image",
"diffusers",
lora,
"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)
def log_validation(
+147 -23
View File
@@ -10,7 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/)|
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/), and [Ednaordinary](https://github.com/Ednaordinary)|
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
@@ -24,12 +24,12 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/long_prompt_weighting_stable_diffusion.ipynb) | [SkyTNT](https://github.com/SkyTNT) |
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "&#124;" in prompts (as an AND condition) and weights (separated by "&#124;" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/composable_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/seed_resizing.ipynb) | [Mark Rich](https://github.com/MarkRich) |
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/imagic_stable_diffusion.ipynb) | [Mark Rich](https://github.com/MarkRich) |
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/gluegen_stable_diffusion.ipynb) | [Phạm Hồng Vinh](https://github.com/rootonchair) |
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/image_to_image_inpainting_stable_diffusion.ipynb) | [Alex McKinney](https://github.com/vvvm23) |
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/text_based_inpainting_stable_dffusion.ipynb) | [Dhruv Karan](https://github.com/unography) |
| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) |
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
@@ -41,7 +41,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/unclip_image_interpolation.ipynb)| [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ddim_noise_comparative_analysis.ipynb)| [Aengus (Duc-Anh)](https://github.com/aengusng8) |
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/clip_guided_img2img_stable_diffusion.ipynb) | [Nipun Jindal](https://github.com/nipunjindal/) |
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/tensorrt_text2image_stable_diffusion_pipeline.ipynb) | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint )|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_repaint.ipynb)| [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
@@ -58,7 +58,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| FABRIC - Stable Diffusion with feedback Pipeline | pipeline supports feedback from liked and disliked images | [Stable Diffusion Fabric Pipeline](#stable-diffusion-fabric-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_fabric.ipynb)| [Shauray Singh](https://shauray8.github.io/about_shauray/) |
| sketch inpaint - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion Pipeline](#stable-diffusion-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| sketch inpaint xl - Inpainting with non-inpaint Stable Diffusion | sketch inpaint much like in automatic1111 | [Masked Im2Im Stable Diffusion XL Pipeline](#stable-diffusion-xl-masked-im2im) | - | [Anatoly Belikov](https://github.com/noskill) |
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | - | [Umer H. Adil](https://twitter.com/UmerHAdil) |
| prompt-to-prompt | change parts of a prompt and retain image structure (see [paper page](https://prompt-to-prompt.github.io/)) | [Prompt2Prompt Pipeline](#prompt2prompt-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_2_prompt_pipeline.ipynb) | [Umer H. Adil](https://twitter.com/UmerHAdil) |
| Latent Consistency Pipeline | Implementation of [Latent Consistency Models: Synthesizing High-Resolution Images with Few-Step Inference](https://arxiv.org/abs/2310.04378) | [Latent Consistency Pipeline](#latent-consistency-pipeline) | - | [Simian Luo](https://github.com/luosiallen) |
| Latent Consistency Img2img Pipeline | Img2img pipeline for Latent Consistency Models | [Latent Consistency Img2Img Pipeline](#latent-consistency-img2img-pipeline) | - | [Logan Zoellner](https://github.com/nagolinc) |
| Latent Consistency Interpolation Pipeline | Interpolate the latent space of Latent Consistency Models with multiple prompts | [Latent Consistency Interpolation Pipeline](#latent-consistency-interpolation-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1pK3NrLWJSiJsBynLns1K1-IDTW9zbPvl?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
@@ -85,7 +85,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
| Perturbed-Attention Guidance |StableDiffusionPAGPipeline is a modification of StableDiffusionPipeline to support Perturbed-Attention Guidance (PAG).|[Perturbed-Attention Guidance](#perturbed-attention-guidance)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/perturbed_attention_guidance.ipynb)|[Hyoungwon Cho](https://github.com/HyoungwonCho)|
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://arxiv.org/abs/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -124,7 +124,6 @@ pipe = pipe.to("cuda")
#--------Option--------#
prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
stg_applied_layers_idx = [34]
stg_mode = "STG"
stg_scale = 1.0 # 0.0 for CFG
#----------------------#
@@ -954,6 +953,7 @@ for i in range(args.num_images):
images.append(th.from_numpy(np.array(image)).permute(2, 0, 1) / 255.)
grid = tvu.make_grid(th.stack(images, dim=0), nrow=4, padding=0)
tvu.save_image(grid, f'{prompt}_{args.weights}' + '.png')
print("Image saved successfully!")
```
### Imagic Stable Diffusion
@@ -1269,28 +1269,39 @@ The aim is to overlay two images, then mask out the boundary between `image` and
For example, this could be used to place a logo on a shirt and make it blend seamlessly.
```python
import PIL
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import DiffusionPipeline
image_path = "./path-to-image.png"
inner_image_path = "./path-to-inner-image.png"
mask_path = "./path-to-mask.png"
image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
inner_image_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = PIL.Image.open(image_path).convert("RGB").resize((512, 512))
inner_image = PIL.Image.open(inner_image_path).convert("RGBA").resize((512, 512))
mask_image = PIL.Image.open(mask_path).convert("RGB").resize((512, 512))
def load_image(url, mode="RGB"):
response = requests.get(url)
if response.status_code == 200:
return Image.open(BytesIO(response.content)).convert(mode).resize((512, 512))
else:
raise FileNotFoundError(f"Could not retrieve image from {url}")
init_image = load_image(image_url, mode="RGB")
inner_image = load_image(inner_image_url, mode="RGBA")
mask_image = load_image(mask_url, mode="RGB")
pipe = DiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-inpainting",
"stable-diffusion-v1-5/stable-diffusion-inpainting",
custom_pipeline="img2img_inpainting",
torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
prompt = "Your prompt here!"
prompt = "a mecha robot sitting on a bench"
image = pipe(prompt=prompt, image=init_image, inner_image=inner_image, mask_image=mask_image).images[0]
image.save("output.png")
```
![2 by 2 grid demonstrating image to image inpainting.](https://user-images.githubusercontent.com/44398246/203506577-ec303be4-887e-4ebd-a773-c83fcb3dd01a.png)
@@ -3252,14 +3263,19 @@ Here's a full example for `ReplaceEdit``:
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline
import numpy as np
from PIL import Image
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", custom_pipeline="pipeline_prompt2prompt").to("cuda")
pipe = DiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
custom_pipeline="pipeline_prompt2prompt"
).to("cuda")
prompts = ["A turtle playing with a ball",
"A monkey playing with a ball"]
prompts = [
"A turtle playing with a ball",
"A monkey playing with a ball"
]
cross_attention_kwargs = {
"edit_type": "replace",
@@ -3267,7 +3283,15 @@ cross_attention_kwargs = {
"self_replace_steps": 0.4
}
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=50, cross_attention_kwargs=cross_attention_kwargs)
outputs = pipe(
prompt=prompts,
height=512,
width=512,
num_inference_steps=50,
cross_attention_kwargs=cross_attention_kwargs
)
outputs.images[0].save("output_image_0.png")
```
And abbreviated examples for the other edits:
@@ -5309,3 +5333,103 @@ output = pipeline_for_inversion(
pipeline.export_latents_to_video(output.inverse_latents[-1], "path/to/inverse_video.mp4", fps=8)
pipeline.export_latents_to_video(output.recon_latents[-1], "path/to/recon_video.mp4", fps=8)
```
# FaithDiff Stable Diffusion XL Pipeline
[Project](https://jychen9811.github.io/FaithDiff_page/) / [GitHub](https://github.com/JyChen9811/FaithDiff/)
This the implementation of the FaithDiff pipeline for SDXL, adapted to use the HuggingFace Diffusers.
For more details see the project links above.
## Example Usage
This example upscale and restores a low-quality image. The input image has a resolution of 512x512 and will be upscaled at a scale of 2x, to a final resolution of 1024x1024. It is possible to upscale to a larger scale, but it is recommended that the input image be at least 1024x1024 in these cases. To upscale this image by 4x, for example, it would be recommended to re-input the result into a new 2x processing, thus performing progressive scaling.
````py
import random
import numpy as np
import torch
from diffusers import DiffusionPipeline, AutoencoderKL, UniPCMultistepScheduler
from huggingface_hub import hf_hub_download
from diffusers.utils import load_image
from PIL import Image
device = "cuda"
dtype = torch.float16
MAX_SEED = np.iinfo(np.int32).max
# Download weights for additional unet layers
model_file = hf_hub_download(
"jychen9811/FaithDiff",
filename="FaithDiff.bin", local_dir="./proc_data/faithdiff", local_dir_use_symlinks=False
)
# Initialize the models and pipeline
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=dtype)
model_id = "SG161222/RealVisXL_V4.0"
pipe = DiffusionPipeline.from_pretrained(
model_id,
torch_dtype=dtype,
vae=vae,
unet=None, #<- Do not load with original model.
custom_pipeline="pipeline_faithdiff_stable_diffusion_xl",
use_safetensors=True,
variant="fp16",
).to(device)
# Here we need use pipeline internal unet model
pipe.unet = pipe.unet_model.from_pretrained(model_id, subfolder="unet", variant="fp16", use_safetensors=True)
# Load aditional layers to the model
pipe.unet.load_additional_layers(weight_path="proc_data/faithdiff/FaithDiff.bin", dtype=dtype)
# Enable vae tiling
pipe.set_encoder_tile_settings()
pipe.enable_vae_tiling()
# Optimization
pipe.enable_model_cpu_offload()
# Set selected scheduler
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
#input params
prompt = "The image features a woman in her 55s with blonde hair and a white shirt, smiling at the camera. She appears to be in a good mood and is wearing a white scarf around her neck. "
upscale = 2 # scale here
start_point = "lr" # or "noise"
latent_tiled_overlap = 0.5
latent_tiled_size = 1024
# Load image
lq_image = load_image("https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/woman.png")
original_height = lq_image.height
original_width = lq_image.width
print(f"Current resolution: H:{original_height} x W:{original_width}")
width = original_width * int(upscale)
height = original_height * int(upscale)
print(f"Final resolution: H:{height} x W:{width}")
# Restoration
image = lq_image.resize((width, height), Image.LANCZOS)
input_image, width_init, height_init, width_now, height_now = pipe.check_image_size(image)
generator = torch.Generator(device=device).manual_seed(random.randint(0, MAX_SEED))
gen_image = pipe(lr_img=input_image,
prompt = prompt,
num_inference_steps=20,
guidance_scale=5,
generator=generator,
start_point=start_point,
height = height_now,
width=width_now,
overlap=latent_tiled_overlap,
target_size=(latent_tiled_size, latent_tiled_size)
).images[0]
cropped_image = gen_image.crop((0, 0, width_init, height_init))
cropped_image.save("data/result.png")
````
### Result
[<img src="https://huggingface.co/datasets/DEVAIEXP/assets/resolve/main/faithdiff_restored.PNG" width="512px" height="512px"/>](https://imgsli.com/MzY1NzE2)
+17 -2
View File
@@ -1773,7 +1773,7 @@ class SDXLLongPromptWeightingPipeline(
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
" `pipeline.unet` or your `mask_image` or `image` input."
)
elif num_channels_unet != 4:
@@ -1924,7 +1924,22 @@ class SDXLLongPromptWeightingPipeline(
self.upcast_vae()
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
if has_latents_mean and has_latents_std:
latents_mean = (
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
else:
latents = latents / self.vae.config.scaling_factor
image = self.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if needs_upcasting:
File diff suppressed because it is too large Load Diff
+661
View File
@@ -0,0 +1,661 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import types
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import WanLoraLoaderMixin
from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from diffusers import AutoencoderKLWan
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
>>> from examples.community.pipeline_stg_wan import WanSTGPipeline
>>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
>>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
>>> pipe = WanSTGPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
>>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
>>> pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
>>> # Configure STG mode options
>>> stg_applied_layers_idx = [8] # Layer indices from 0 to 39 for 14b or 0 to 29 for 1.3b
>>> stg_scale = 1.0 # Set 0.0 for CFG
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=720,
... width=1280,
... num_frames=81,
... guidance_scale=5.0,
... stg_applied_layers_idx=stg_applied_layers_idx,
... stg_scale=stg_scale,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=16)
```
"""
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
def forward_with_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
return hidden_states
def forward_without_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class WanSTGPipeline(DiffusionPipeline, WanLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using Wan.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
tokenizer ([`T5Tokenizer`]):
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def check_inputs(
self,
prompt,
negative_prompt,
height,
width,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
num_frames: int = 81,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def do_spatio_temporal_guidance(self):
return self._stg_scale > 0.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@property
def attention_kwargs(self):
return self._attention_kwargs
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
stg_applied_layers_idx: Optional[List[int]] = [3, 8, 16],
stg_scale: Optional[float] = 0.0,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, defaults to `480`):
The height in pixels of the generated image.
width (`int`, defaults to `832`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `81`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
The dtype to use for the torch.amp.autocast.
Examples:
Returns:
[`~WanPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated images and the second element is a list of `bool`s
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
negative_prompt,
height,
width,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._stg_scale = stg_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
if self.do_spatio_temporal_guidance:
for idx, block in enumerate(self.transformer.blocks):
block.forward = types.MethodType(forward_without_stg, block)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_spatio_temporal_guidance:
for idx, block in enumerate(self.transformer.blocks):
if idx in stg_applied_layers_idx:
block.forward = types.MethodType(forward_with_stg, block)
noise_perturb = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = (
noise_uncond
+ guidance_scale * (noise_pred - noise_uncond)
+ self._stg_scale * (noise_pred - noise_perturb)
)
else:
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return WanPipelineOutput(frames=video)
+18 -7
View File
@@ -927,17 +927,22 @@ def main(args):
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
@@ -962,8 +967,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+19 -7
View File
@@ -895,7 +895,10 @@ def _encode_prompt_with_t5(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -936,9 +939,13 @@ def _encode_prompt_with_clip(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -958,7 +965,12 @@ def encode_prompt(
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
dtype = text_encoders[0].dtype
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
device = device if device is not None else text_encoders[1].device
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
@@ -1590,7 +1602,7 @@ def main(args):
)
# handle guidance
if accelerator.unwrap_model(transformer).config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1716,9 +1728,9 @@ def main(args):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
text_encoder_2=accelerator.unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
text_encoder=unwrap_model(text_encoder_one, keep_fp32_wrapper=False),
text_encoder_2=unwrap_model(text_encoder_two, keep_fp32_wrapper=False),
transformer=unwrap_model(transformer, keep_fp32_wrapper=False),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -177,16 +177,25 @@ def log_validation(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
).images[0]
images.append(image)
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
@@ -203,8 +212,7 @@ def log_validation(
)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
free_memory()
return images
@@ -932,7 +940,10 @@ def _encode_prompt_with_t5(
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -973,9 +984,13 @@ def _encode_prompt_with_clip(
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
if hasattr(text_encoder, "module"):
dtype = text_encoder.module.dtype
else:
dtype = text_encoder.dtype
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -994,7 +1009,11 @@ def encode_prompt(
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
dtype = text_encoders[0].dtype
if hasattr(text_encoders[0], "module"):
dtype = text_encoders[0].module.dtype
else:
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
@@ -1619,7 +1638,7 @@ def main(args):
if args.train_text_encoder:
text_encoder_one.train()
# set top parameter requires_grad = True for gradient checkpointing works
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
for step, batch in enumerate(train_dataloader):
models_to_accumulate = [transformer]
@@ -1710,7 +1729,7 @@ def main(args):
)
# handle guidance
if accelerator.unwrap_model(transformer).config.guidance_embeds:
if unwrap_model(transformer).config.guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
@@ -1828,9 +1847,9 @@ def main(args):
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
text_encoder=unwrap_model(text_encoder_one),
text_encoder_2=unwrap_model(text_encoder_two),
transformer=unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -669,6 +669,16 @@ def parse_args(input_args=None):
),
)
parser.add_argument(
"--image_interpolation_mode",
type=str,
default="lanczos",
choices=[
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
],
help="The image interpolation method to use for resizing images.",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
@@ -790,7 +800,12 @@ class DreamBoothDataset(Dataset):
self.original_sizes = []
self.crop_top_lefts = []
self.pixel_values = []
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
@@ -49,6 +49,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -418,7 +419,7 @@ def convert_to_np(image, resolution):
def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
+12 -4
View File
@@ -1,20 +1,27 @@
# AnyTextPipeline Pipeline
# AnyTextPipeline
Project page: https://aigcdesigngroup.github.io/homepage_anytext
"AnyText comprises a diffusion pipeline with two primary elements: an auxiliary latent module and a text embedding module. The former uses inputs like text glyph, position, and masked image to generate latent features for text generation or editing. The latter employs an OCR model for encoding stroke data as embeddings, which blend with image caption embeddings from the tokenizer to generate texts that seamlessly integrate with the background. We employed text-control diffusion loss and text perceptual loss for training to further enhance writing accuracy."
Each text line that needs to be generated should be enclosed in double quotes. For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
> **Note:** Each text line that needs to be generated should be enclosed in double quotes.
For any usage questions, please refer to the [paper](https://arxiv.org/abs/2311.03054).
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/b87ec9d2f265b448dd947c9d4a0da389/anytext.ipynb)
```py
# This example requires the `anytext_controlnet.py` file:
# !git clone --depth 1 https://github.com/huggingface/diffusers.git
# %cd diffusers/examples/research_projects/anytext
# Let's choose a font file shared by an HF staff:
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
import torch
from diffusers import DiffusionPipeline
from anytext_controlnet import AnyTextControlNetModel
from diffusers.utils import load_image
# I chose a font file shared by an HF staff:
# !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
variant="fp16",)
@@ -26,6 +33,7 @@ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial
# generate image
prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
# There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
).images[0]
image
+11 -5
View File
@@ -146,14 +146,17 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # This example requires the `anytext_controlnet.py` file:
>>> # !git clone --depth 1 https://github.com/huggingface/diffusers.git
>>> # %cd diffusers/examples/research_projects/anytext
>>> # Let's choose a font file shared by an HF staff:
>>> # !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from anytext_controlnet import AnyTextControlNetModel
>>> from diffusers.utils import load_image
>>> # I chose a font file shared by an HF staff:
>>> !wget https://huggingface.co/spaces/ysharma/TranslateQuotesInImageForwards/resolve/main/arial-unicode-ms.ttf
>>> anytext_controlnet = AnyTextControlNetModel.from_pretrained("tolgacangoz/anytext-controlnet", torch_dtype=torch.float16,
... variant="fp16",)
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/anytext", font_path="arial-unicode-ms.ttf",
@@ -165,6 +168,7 @@ EXAMPLE_DOC_STRING = """
>>> # generate image
>>> prompt = 'photo of caramel macchiato coffee on the table, top-down perspective, with "Any" "Text" written on it using cream'
>>> draw_pos = load_image("https://raw.githubusercontent.com/tyxsspa/AnyText/refs/heads/main/example_images/gen9.png")
>>> # There are two modes: "generate" and "edit". "edit" mode requires `ori_image` parameter for the image to be edited.
>>> image = pipe(prompt, num_inference_steps=20, mode="generate", draw_pos=draw_pos,
... ).images[0]
>>> image
@@ -257,11 +261,11 @@ class EmbeddingManager(ModelMixin, ConfigMixin):
idx = tokenized_text[i] == self.placeholder_token.to(device)
if sum(idx) > 0:
if i >= len(self.text_embs_all):
print("truncation for log images...")
logger.warning("truncation for log images...")
break
text_emb = torch.cat(self.text_embs_all[i], dim=0)
if sum(idx) != len(text_emb):
print("truncation for long caption...")
logger.warning("truncation for long caption...")
text_emb = text_emb.to(embedded_text.device)
embedded_text[i][idx] = text_emb[: sum(idx)]
return embedded_text
@@ -1058,6 +1062,8 @@ class AuxiliaryLatentModule(ModelMixin, ConfigMixin):
raise ValueError(f"Can't read ori_image image from {ori_image}!")
elif isinstance(ori_image, torch.Tensor):
ori_image = ori_image.cpu().numpy()
elif isinstance(ori_image, PIL.Image.Image):
ori_image = np.array(ori_image.convert("RGB"))
else:
if not isinstance(ori_image, np.ndarray):
raise ValueError(f"Unknown format of ori_image: {type(ori_image)}")
@@ -627,6 +627,7 @@ def main(args):
ema_vae = EMAModel(vae.parameters(), model_cls=AutoencoderKL, model_config=vae.config)
perceptual_loss = lpips.LPIPS(net="vgg").eval()
discriminator = NLayerDiscriminator(input_nc=3, n_layers=3, use_actnorm=False).apply(weights_init)
discriminator = torch.nn.SyncBatchNorm.convert_sync_batchnorm(discriminator)
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
def unwrap_model(model):
@@ -951,13 +952,20 @@ def main(args):
logits_fake = discriminator(reconstructions)
disc_loss = hinge_d_loss if args.disc_loss == "hinge" else vanilla_d_loss
disc_factor = args.disc_factor if global_step >= args.disc_start else 0.0
disc_loss = disc_factor * disc_loss(logits_real, logits_fake)
d_loss = disc_factor * disc_loss(logits_real, logits_fake)
logs = {
"disc_loss": disc_loss.detach().mean().item(),
"disc_loss": d_loss.detach().mean().item(),
"logits_real": logits_real.detach().mean().item(),
"logits_fake": logits_fake.detach().mean().item(),
"disc_lr": disc_lr_scheduler.get_last_lr()[0],
}
accelerator.backward(d_loss)
if accelerator.sync_gradients:
params_to_clip = discriminator.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
disc_optimizer.step()
disc_lr_scheduler.step()
disc_optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -54,6 +54,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -475,7 +476,7 @@ def convert_to_np(image, resolution):
def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
@@ -59,6 +59,7 @@ from diffusers.schedulers import (
UnCLIPScheduler,
)
from diffusers.utils import is_accelerate_available, logging
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
if is_accelerate_available():
@@ -1435,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
@@ -1,8 +1,6 @@
# Generating images using Flux and PyTorch/XLA
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation.
It has been tested on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
The `flux_inference` script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on [Trillium](https://cloud.google.com/blog/products/compute/introducing-trillium-6th-gen-tpus) TPU versions. No other TPU types have been tested.
## Create TPU
@@ -23,20 +21,23 @@ Verify that PyTorch and PyTorch/XLA were installed correctly:
python3 -c "import torch; import torch_xla;"
```
Install dependencies
Clone the diffusers repo and install dependencies
```bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install transformers accelerate sentencepiece structlog
pushd ../../..
pip install .
popd
cd examples/research_projects/pytorch_xla/inference/flux/
```
## Run the inference job
### Authenticate
Run the following command to authenticate your token in order to download Flux weights.
**Gated Model**
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
```bash
huggingface-cli login
@@ -11,6 +11,7 @@ from diffusion import sampling
from torch import nn
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
MODELS_MAP = {
@@ -74,7 +75,7 @@ class DiffusionUncond(nn.Module):
def download(model_name):
url = MODELS_MAP[model_name]["url"]
r = requests.get(url, stream=True)
r = requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
local_filename = f"./{model_name}.ckpt"
with open(local_filename, "wb") as fp:
+22 -1
View File
@@ -160,8 +160,9 @@ TRANSFORMER_CONFIGS = {
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": None,
},
"HYVideo-T/2-I2V": {
"HYVideo-T/2-I2V-33ch": {
"in_channels": 16 * 2 + 1,
"out_channels": 16,
"num_attention_heads": 24,
@@ -178,6 +179,26 @@ TRANSFORMER_CONFIGS = {
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": "latent_concat",
},
"HYVideo-T/2-I2V-16ch": {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 24,
"attention_head_dim": 128,
"num_layers": 20,
"num_single_layers": 40,
"num_refiner_layers": 2,
"mlp_ratio": 4.0,
"patch_size": 2,
"patch_size_t": 1,
"qk_norm": "rms_norm",
"guidance_embeds": True,
"text_embed_dim": 4096,
"pooled_projection_dim": 768,
"rope_theta": 256.0,
"rope_axes_dim": (16, 56, 56),
"image_condition_type": "token_replace",
},
}
+89 -15
View File
@@ -74,6 +74,32 @@ VAE_091_RENAME_DICT = {
"last_scale_shift_table": "scale_shift_table",
}
VAE_095_RENAME_DICT = {
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# common
"last_time_embedder": "time_embedder",
"last_scale_shift_table": "scale_shift_table",
}
VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
@@ -81,10 +107,6 @@ VAE_SPECIAL_KEYS_REMAP = {
"model.diffusion_model": remove_keys_,
}
VAE_091_SPECIAL_KEYS_REMAP = {
"timestep_scale_multiplier": remove_keys_,
}
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
state_dict = saved_dict
@@ -104,12 +126,16 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
def convert_transformer(
ckpt_path: str,
dtype: torch.dtype,
version: str = "0.9.0",
):
PREFIX_KEY = "model.diffusion_model."
original_state_dict = get_state_dict(load_file(ckpt_path))
config = {}
if version == "0.9.5":
config["_use_causal_rope_fix"] = True
with init_empty_weights():
transformer = LTXVideoTransformer3DModel()
transformer = LTXVideoTransformer3DModel(**config)
for key in list(original_state_dict.keys()):
new_key = key[:]
@@ -161,12 +187,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (128, 256, 512, 512),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (4, 3, 3, 3, 4),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True, False),
"decoder_inject_noise": (False, False, False, False, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (False, False, False, False),
"upsample_factor": (1, 1, 1, 1),
"patch_size": 4,
@@ -183,12 +216,19 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 512),
"down_block_types": (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 3, 3, 3, 4),
"decoder_layers_per_block": (5, 6, 7, 8),
"spatio_temporal_scaling": (True, True, True, False),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (True, True, True, False),
"downsample_type": ("conv", "conv", "conv", "conv"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
@@ -200,7 +240,38 @@ def get_vae_config(version: str) -> Dict[str, Any]:
"decoder_causal": False,
}
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
VAE_SPECIAL_KEYS_REMAP.update(VAE_091_SPECIAL_KEYS_REMAP)
elif version == "0.9.5":
config = {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (128, 256, 512, 1024, 2048),
"down_block_types": (
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
"LTXVideo095DownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": True,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"scaling_factor": 1.0,
"encoder_causal": True,
"decoder_causal": False,
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
}
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
return config
@@ -223,7 +294,7 @@ def get_args():
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
parser.add_argument(
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1"], help="Version of the LTX model"
"--version", type=str, default="0.9.0", choices=["0.9.0", "0.9.1", "0.9.5"], help="Version of the LTX model"
)
return parser.parse_args()
@@ -277,14 +348,17 @@ if __name__ == "__main__":
for param in text_encoder.parameters():
param.data = param.data.contiguous()
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
if args.version == "0.9.5":
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=False)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTXPipeline(
scheduler=scheduler,
+200 -62
View File
@@ -16,7 +16,9 @@ from diffusers import (
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaPipeline,
SanaSprintPipeline,
SanaTransformer2DModel,
SCMScheduler,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils.import_utils import is_accelerate_available
@@ -25,6 +27,10 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/Sana_Sprint_0.6B_1024px/checkpoints/Sana_Sprint_0.6B_1024px.pth"
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px/checkpoints/Sana_Sprint_1.6B_1024px.pth"
"Efficient-Large-Model/SANA1.5_4.8B_1024px/checkpoints/SANA1.5_4.8B_1024px.pth",
"Efficient-Large-Model/SANA1.5_1.6B_1024px/checkpoints/SANA1.5_1.6B_1024px.pth",
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
@@ -72,15 +78,42 @@ def main(args):
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
# AdaLN-single LN
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Handle different time embedding structure based on model type
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# For Sana Sprint, the time embedding structure is different
converted_state_dict["time_embed.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
converted_state_dict["time_embed.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
# Guidance embedder for Sana Sprint
converted_state_dict["time_embed.guidance_embedder.linear_1.weight"] = state_dict.pop(
"cfg_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.guidance_embedder.linear_1.bias"] = state_dict.pop("cfg_embedder.mlp.0.bias")
converted_state_dict["time_embed.guidance_embedder.linear_2.weight"] = state_dict.pop(
"cfg_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.guidance_embedder.linear_2.bias"] = state_dict.pop("cfg_embedder.mlp.2.bias")
else:
# Original Sana time embedding structure
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
"t_embedder.mlp.0.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop(
"t_embedder.mlp.0.bias"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
"t_embedder.mlp.2.weight"
)
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop(
"t_embedder.mlp.2.bias"
)
# Shared norm.
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
@@ -96,14 +129,22 @@ def main(args):
flow_shift = 3.0
# model config
if args.model_type == "SanaMS_1600M_P1_D20":
if args.model_type in ["SanaMS_1600M_P1_D20", "SanaSprint_1600M_P1_D20", "SanaMS1.5_1600M_P1_D20"]:
layer_num = 20
elif args.model_type == "SanaMS_600M_P1_D28":
elif args.model_type in ["SanaMS_600M_P1_D28", "SanaSprint_600M_P1_D28"]:
layer_num = 28
elif args.model_type == "SanaMS_4800M_P1_D60":
layer_num = 60
else:
raise ValueError(f"{args.model_type} is not supported.")
# Positional embedding interpolation scale.
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
qk_norm = (
"rms_norm_across_heads"
if args.model_type
in ["SanaMS1.5_1600M_P1_D20", "SanaMS1.5_4800M_P1_D60", "SanaSprint_600M_P1_D28", "SanaSprint_1600M_P1_D20"]
else None
)
for depth in range(layer_num):
# Transformer blocks.
@@ -117,6 +158,14 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
if qk_norm is not None:
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.attn.k_norm.weight"
)
# Projection.
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.attn.proj.weight"
@@ -154,6 +203,14 @@ def main(args):
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
if qk_norm is not None:
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.q_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.k_norm.weight"
)
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
f"blocks.{depth}.cross_attn.proj.weight"
@@ -169,24 +226,37 @@ def main(args):
# Transformer
with CTX():
transformer = SanaTransformer2DModel(
in_channels=32,
out_channels=32,
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
num_layers=model_kwargs[args.model_type]["num_layers"],
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
caption_channels=2304,
mlp_ratio=2.5,
attention_bias=False,
sample_size=args.image_size // 32,
patch_size=1,
norm_elementwise_affine=False,
norm_eps=1e-6,
interpolation_scale=interpolation_scale[args.image_size],
)
transformer_kwargs = {
"in_channels": 32,
"out_channels": 32,
"num_attention_heads": model_kwargs[args.model_type]["num_attention_heads"],
"attention_head_dim": model_kwargs[args.model_type]["attention_head_dim"],
"num_layers": model_kwargs[args.model_type]["num_layers"],
"num_cross_attention_heads": model_kwargs[args.model_type]["num_cross_attention_heads"],
"cross_attention_head_dim": model_kwargs[args.model_type]["cross_attention_head_dim"],
"cross_attention_dim": model_kwargs[args.model_type]["cross_attention_dim"],
"caption_channels": 2304,
"mlp_ratio": 2.5,
"attention_bias": False,
"sample_size": args.image_size // 32,
"patch_size": 1,
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"interpolation_scale": interpolation_scale[args.image_size],
}
# Add qk_norm parameter for Sana Sprint
if args.model_type in [
"SanaMS1.5_1600M_P1_D20",
"SanaMS1.5_4800M_P1_D60",
"SanaSprint_600M_P1_D28",
"SanaSprint_1600M_P1_D20",
]:
transformer_kwargs["qk_norm"] = "rms_norm_across_heads"
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
transformer_kwargs["guidance_embeds"] = True
transformer = SanaTransformer2DModel(**transformer_kwargs)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_state_dict)
@@ -196,6 +266,8 @@ def main(args):
try:
state_dict.pop("y_embedder.y_embedding")
state_dict.pop("pos_embed")
state_dict.pop("logvar_linear.weight")
state_dict.pop("logvar_linear.bias")
except KeyError:
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
@@ -210,47 +282,74 @@ def main(args):
print(
colored(
f"Only saving transformer model of {args.model_type}. "
f"Set --save_full_pipeline to save the whole SanaPipeline",
f"Set --save_full_pipeline to save the whole Pipeline",
"green",
attrs=["bold"],
)
)
transformer.save_pretrained(
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
)
else:
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "google/gemma-2-2b-it"
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
tokenizer.padding_side = "right"
text_encoder = AutoModelForCausalLM.from_pretrained(
text_encoder_model_path, torch_dtype=torch.bfloat16
).get_decoder()
# Scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
# Choose the appropriate pipeline and scheduler based on model type
if args.model_type in ["SanaSprint_1600M_P1_D20", "SanaSprint_600M_P1_D28"]:
# Force SCM Scheduler for Sana Sprint regardless of scheduler_type
if args.scheduler_type != "scm":
print(
colored(
f"Warning: Overriding scheduler_type '{args.scheduler_type}' to 'scm' for SanaSprint model",
"yellow",
attrs=["bold"],
)
)
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
# SCM Scheduler for Sana Sprint
scheduler_config = {
"prediction_type": "trigflow",
"sigma_data": 0.5,
}
scheduler = SCMScheduler(**scheduler_config)
pipe = SanaSprintPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
else:
# Original Sana scheduler
if args.scheduler_type == "flow-dpm_solver":
scheduler = DPMSolverMultistepScheduler(
flow_shift=flow_shift,
use_flow_sigmas=True,
prediction_type="flow_prediction",
)
elif args.scheduler_type == "flow-euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
else:
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
pipe = SanaPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
transformer=transformer,
vae=ae,
scheduler=scheduler,
)
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
DTYPE_MAPPING = {
@@ -259,12 +358,6 @@ DTYPE_MAPPING = {
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -281,10 +374,24 @@ if __name__ == "__main__":
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
)
parser.add_argument(
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
"--model_type",
default="SanaMS_1600M_P1_D20",
type=str,
choices=[
"SanaMS_1600M_P1_D20",
"SanaMS_600M_P1_D28",
"SanaMS1.5_1600M_P1_D20",
"SanaMS1.5_4800M_P1_D60",
"SanaSprint_1600M_P1_D20",
"SanaSprint_600M_P1_D28",
],
)
parser.add_argument(
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
"--scheduler_type",
default="flow-dpm_solver",
type=str,
choices=["flow-dpm_solver", "flow-euler", "scm"],
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
@@ -309,10 +416,41 @@ if __name__ == "__main__":
"cross_attention_dim": 1152,
"num_layers": 28,
},
"SanaMS1.5_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
"SanaMS1.5_4800M_P1_D60": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 60,
},
"SanaSprint_600M_P1_D28": {
"num_attention_heads": 36,
"attention_head_dim": 32,
"num_cross_attention_heads": 16,
"cross_attention_head_dim": 72,
"cross_attention_dim": 1152,
"num_layers": 28,
},
"SanaSprint_1600M_P1_D20": {
"num_attention_heads": 70,
"attention_head_dim": 32,
"num_cross_attention_heads": 20,
"cross_attention_head_dim": 112,
"cross_attention_dim": 2240,
"num_layers": 20,
},
}
device = "cuda" if torch.cuda.is_available() else "cpu"
weight_dtype = DTYPE_MAPPING[args.dtype]
variant = VARIANT_MAPPING[args.dtype]
main(args)
+3 -1
View File
@@ -13,6 +13,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
renew_vae_attention_paths,
renew_vae_resnet_paths,
)
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
@@ -122,7 +123,8 @@ def vae_pt_to_vae_diffuser(
):
# Only support V1
r = requests.get(
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
timeout=DIFFUSERS_REQUEST_TIMEOUT,
)
io_obj = io.BytesIO(r.content)
+17 -1
View File
@@ -131,8 +131,10 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_pyramid_attention_broadcast",
]
)
@@ -271,6 +273,7 @@ else:
"RePaintScheduler",
"SASolverScheduler",
"SchedulerMixin",
"SCMScheduler",
"ScoreSdeVeScheduler",
"TCDScheduler",
"UnCLIPScheduler",
@@ -402,6 +405,7 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LTXConditionPipeline",
"LTXImageToVideoPipeline",
"LTXPipeline",
"Lumina2Pipeline",
@@ -422,6 +426,7 @@ else:
"ReduxImageEncoder",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -504,6 +509,7 @@ else:
"VQDiffusionPipeline",
"WanImageToVideoPipeline",
"WanPipeline",
"WanVideoToVideoPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
@@ -702,7 +708,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .hooks import HookRegistry, PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .hooks import (
FasterCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
@@ -835,6 +847,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
RePaintScheduler,
SASolverScheduler,
SchedulerMixin,
SCMScheduler,
ScoreSdeVeScheduler,
TCDScheduler,
UnCLIPScheduler,
@@ -947,6 +960,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXPipeline,
Lumina2Pipeline,
@@ -967,6 +981,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ReduxImageEncoder,
SanaPAGPipeline,
SanaPipeline,
SanaSprintPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
@@ -1048,6 +1063,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQDiffusionPipeline,
WanImageToVideoPipeline,
WanPipeline,
WanVideoToVideoPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
+4 -1
View File
@@ -35,6 +35,7 @@ from huggingface_hub.utils import (
validate_hf_hub_args,
)
from requests import HTTPError
from typing_extensions import Self
from . import __version__
from .utils import (
@@ -185,7 +186,9 @@ class ConfigMixin:
)
@classmethod
def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
def from_config(
cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
r"""
Instantiate a Python class from a config dictionary.
+1
View File
@@ -2,6 +2,7 @@ from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
+653
View File
@@ -0,0 +1,653 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Tuple
import torch
from ..models.attention_processor import Attention, MochiAttention
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
"^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = (
"hidden_states",
"encoder_hidden_states",
"timestep",
"attention_mask",
"encoder_attention_mask",
)
@dataclass
class FasterCacheConfig:
r"""
Configuration for [FasterCache](https://huggingface.co/papers/2410.19355).
Attributes:
spatial_attention_block_skip_range (`int`, defaults to `2`):
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
states again.
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`):
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention
states again.
spatial_attention_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 681)`):
The timestep range within which the spatial attention computation can be skipped without a significant loss
in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
timestep 0). For the default values, this would mean that the spatial attention computation skipping will
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising
process.
temporal_attention_timestep_skip_range (`Tuple[float, float]`, *optional*, defaults to `None`):
The timestep range within which the temporal attention computation can be skipped without a significant
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at
timestep 0).
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`):
The timestep range within which the low frequency weight scaling update is applied. The first value in the
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
function for the update is called only within this range.
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`):
The timestep range within which the high frequency weight scaling update is applied. The first value in the
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback
function for the update is called only within this range.
alpha_low_frequency (`float`, defaults to `1.1`):
The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from
the conditional branch outputs.
alpha_high_frequency (`float`, defaults to `1.1`):
The weight to scale the high frequency updates by. This is used to approximate the unconditional branch
from the conditional branch outputs.
unconditional_batch_skip_range (`int`, defaults to `5`):
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before
computing the new unconditional branch states again.
unconditional_batch_timestep_skip_range (`Tuple[float, float]`, defaults to `(-1, 641)`):
The timestep range within which the unconditional branch computation can be skipped without a significant
loss in quality. This is to be determined by the user based on the underlying model. The first value in the
tuple is the lower bound and the second value is the upper bound.
spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks.*attn1", "transformer_blocks.*attn1", "single_transformer_blocks.*attn1")`):
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
partial layer names, or regex patterns. Matching will always be done using a regex match.
temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks.*attn1",)`):
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names,
partial layer names, or regex patterns. Matching will always be done using a regex match.
attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
The callback function to determine the weight to scale the attention outputs by. This function should take
the attention module as input and return a float value. This is used to approximate the unconditional
branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps.
Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference
progresses. Users are encouraged to experiment and provide custom weight schedules that take into account
the number of inference steps and underlying model behaviour as denoising progresses.
low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
The callback function to determine the weight to scale the low frequency updates by. If not provided, the
default weight is 1.1 for timesteps within the range specified (as described in the paper).
high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`):
The callback function to determine the weight to scale the high frequency updates by. If not provided, the
default weight is 1.1 for timesteps within the range specified (as described in the paper).
tensor_format (`str`, defaults to `"BCFHW"`):
The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is
used to split individual latent frames in order for low and high frequency components to be computed.
is_guidance_distilled (`bool`, defaults to `False`):
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be
applied at the denoiser-level to skip the unconditional branch computation (as there is none).
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`):
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs
names that contain the batchwise-concatenated unconditional and conditional inputs.
"""
# In the paper and codebase, they hardcode these values to 2. However, it can be made configurable
# after some testing. We default to 2 if these parameters are not provided.
spatial_attention_block_skip_range: int = 2
temporal_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
# Indicator functions for low/high frequency as mentioned in Equation 11 of the paper
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
# 1 and 2 as mentioned in Equation 11 of the paper
alpha_low_frequency: float = 1.1
alpha_high_frequency: float = 1.1
# n as described in CFG-Cache explanation in the paper - dependant on the model
unconditional_batch_skip_range: int = 5
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
attention_weight_callback: Callable[[torch.nn.Module], float] = None
low_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
high_frequency_weight_callback: Callable[[torch.nn.Module], float] = None
tensor_format: str = "BCFHW"
is_guidance_distilled: bool = False
current_timestep_callback: Callable[[], int] = None
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS
def __repr__(self) -> str:
return (
f"FasterCacheConfig(\n"
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n"
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n"
f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n"
f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n"
f" alpha_low_frequency={self.alpha_low_frequency},\n"
f" alpha_high_frequency={self.alpha_high_frequency},\n"
f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n"
f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n"
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n"
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n"
f" tensor_format={self.tensor_format},\n"
f")"
)
class FasterCacheDenoiserState:
r"""
State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module.
"""
def __init__(self) -> None:
self.iteration: int = 0
self.low_frequency_delta: torch.Tensor = None
self.high_frequency_delta: torch.Tensor = None
def reset(self):
self.iteration = 0
self.low_frequency_delta = None
self.high_frequency_delta = None
class FasterCacheBlockState:
r"""
State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is
applied to will have an instance of this state.
"""
def __init__(self) -> None:
self.iteration: int = 0
self.batch_size: int = None
self.cache: Tuple[torch.Tensor, torch.Tensor] = None
def reset(self):
self.iteration = 0
self.batch_size = None
self.cache = None
class FasterCacheDenoiserHook(ModelHook):
_is_stateful = True
def __init__(
self,
unconditional_batch_skip_range: int,
unconditional_batch_timestep_skip_range: Tuple[int, int],
tensor_format: str,
is_guidance_distilled: bool,
uncond_cond_input_kwargs_identifiers: List[str],
current_timestep_callback: Callable[[], int],
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor],
) -> None:
super().__init__()
self.unconditional_batch_skip_range = unconditional_batch_skip_range
self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range
# We can't easily detect what args are to be split in unconditional and conditional branches. We
# can only do it for kwargs, hence they are the only ones we split. The args are passed as-is.
# If a model is to be made compatible with FasterCache, the user must ensure that the inputs that
# contain batchwise-concatenated unconditional and conditional inputs are passed as kwargs.
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers
self.tensor_format = tensor_format
self.is_guidance_distilled = is_guidance_distilled
self.current_timestep_callback = current_timestep_callback
self.low_frequency_weight_callback = low_frequency_weight_callback
self.high_frequency_weight_callback = high_frequency_weight_callback
def initialize_hook(self, module):
self.state = FasterCacheDenoiserState()
return module
@staticmethod
def _get_cond_input(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Note: this method assumes that the input tensor is batchwise-concatenated with unconditional inputs
# followed by conditional inputs.
_, cond = input.chunk(2, dim=0)
return cond
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
# Split the unconditional and conditional inputs. We only want to infer the conditional branch if the
# requirements for skipping the unconditional branch are met as described in the paper.
# We skip the unconditional branch only if the following conditions are met:
# 1. We have completed at least one iteration of the denoiser
# 2. The current timestep is within the range specified by the user. This is the optimal timestep range
# where approximating the unconditional branch from the computation of the conditional branch is possible
# without a significant loss in quality.
# 3. The current iteration is not a multiple of the unconditional batch skip range. This is done so that
# we compute the unconditional branch at least once every few iterations to ensure minimal quality loss.
is_within_timestep_range = (
self.unconditional_batch_timestep_skip_range[0]
< self.current_timestep_callback()
< self.unconditional_batch_timestep_skip_range[1]
)
should_skip_uncond = (
self.state.iteration > 0
and is_within_timestep_range
and self.state.iteration % self.unconditional_batch_skip_range != 0
and not self.is_guidance_distilled
)
if should_skip_uncond:
is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys())
if is_any_kwarg_uncond:
logger.debug("FasterCache - Skipping unconditional branch computation")
args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args])
kwargs = {
k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v)
for k, v in kwargs.items()
}
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_guidance_distilled:
self.state.iteration += 1
return output
if torch.is_tensor(output):
hidden_states = output
elif isinstance(output, (tuple, Transformer2DModelOutput)):
hidden_states = output[0]
batch_size = hidden_states.size(0)
if should_skip_uncond:
self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback(
module
)
self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback(
module
)
if self.tensor_format == "BCFHW":
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
hidden_states = hidden_states.flatten(0, 1)
low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float())
# Approximate/compute the unconditional branch outputs as described in Equation 9 and 10 of the paper
low_freq_uncond = self.state.low_frequency_delta + low_freq_cond
high_freq_uncond = self.state.high_frequency_delta + high_freq_cond
uncond_freq = low_freq_uncond + high_freq_uncond
uncond_states = torch.fft.ifftshift(uncond_freq)
uncond_states = torch.fft.ifft2(uncond_states).real
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
uncond_states = uncond_states.unflatten(0, (batch_size, -1))
hidden_states = hidden_states.unflatten(0, (batch_size, -1))
if self.tensor_format == "BCFHW":
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
# Concatenate the approximated unconditional and predicted conditional branches
uncond_states = uncond_states.to(hidden_states.dtype)
hidden_states = torch.cat([uncond_states, hidden_states], dim=0)
else:
uncond_states, cond_states = hidden_states.chunk(2, dim=0)
if self.tensor_format == "BCFHW":
uncond_states = uncond_states.permute(0, 2, 1, 3, 4)
cond_states = cond_states.permute(0, 2, 1, 3, 4)
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW":
uncond_states = uncond_states.flatten(0, 1)
cond_states = cond_states.flatten(0, 1)
low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float())
low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float())
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond
self.state.high_frequency_delta = high_freq_uncond - high_freq_cond
self.state.iteration += 1
if torch.is_tensor(output):
output = hidden_states
elif isinstance(output, tuple):
output = (hidden_states, *output[1:])
else:
output.sample = hidden_states
return output
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
self.state.reset()
return module
class FasterCacheBlockHook(ModelHook):
_is_stateful = True
def __init__(
self,
block_skip_range: int,
timestep_skip_range: Tuple[int, int],
is_guidance_distilled: bool,
weight_callback: Callable[[torch.nn.Module], float],
current_timestep_callback: Callable[[], int],
) -> None:
super().__init__()
self.block_skip_range = block_skip_range
self.timestep_skip_range = timestep_skip_range
self.is_guidance_distilled = is_guidance_distilled
self.weight_callback = weight_callback
self.current_timestep_callback = current_timestep_callback
def initialize_hook(self, module):
self.state = FasterCacheBlockState()
return module
def _compute_approximated_attention_output(
self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int
) -> torch.Tensor:
if t_2_output.size(0) != batch_size:
# The cache t_2_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
# take the conditional branch outputs.
assert t_2_output.size(0) == 2 * batch_size
t_2_output = t_2_output[batch_size:]
if t_output.size(0) != batch_size:
# The cache t_output contains both batchwise-concatenated unconditional-conditional branch outputs. Just
# take the conditional branch outputs.
assert t_output.size(0) == 2 * batch_size
t_output = t_output[batch_size:]
return t_output + (t_output - t_2_output) * weight
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
batch_size = [
*[arg.size(0) for arg in args if torch.is_tensor(arg)],
*[v.size(0) for v in kwargs.values() if torch.is_tensor(v)],
][0]
if self.state.batch_size is None:
# Will be updated on first forward pass through the denoiser
self.state.batch_size = batch_size
# If we have to skip due to the skip conditions, then let's skip as expected.
# But, we can't skip if the denoiser wants to infer both unconditional and conditional branches. This
# is because the expected output shapes of attention layer will not match if we only return values from
# the cache (which only caches conditional branch outputs). So, if state.batch_size (which is the true
# unconditional-conditional batch size) is same as the current batch size, we don't perform the layer
# skip. Otherwise, we conditionally skip the layer based on what state.skip_callback returns.
is_within_timestep_range = (
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1]
)
if not is_within_timestep_range:
should_skip_attention = False
else:
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0
should_skip_attention = not should_compute_attention
if should_skip_attention:
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size
if should_skip_attention:
logger.debug("FasterCache - Skipping attention and using approximation")
if torch.is_tensor(self.state.cache[-1]):
t_2_output, t_output = self.state.cache
weight = self.weight_callback(module)
output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size)
else:
# The cache contains multiple tensors from past N iterations (N=2 for FasterCache). We need to handle all of them.
# Diffusers blocks can return multiple tensors - let's call them [A, B, C, ...] for simplicity.
# In our cache, we would have [[A_1, B_1, C_1, ...], [A_2, B_2, C_2, ...], ...] where each list is the output from
# a forward pass of the block. We need to compute the approximated output for each of these tensors.
# The zip(*state.cache) operation will give us [(A_1, A_2, ...), (B_1, B_2, ...), (C_1, C_2, ...), ...] which
# allows us to compute the approximated attention output for each tensor in the cache.
output = ()
for t_2_output, t_output in zip(*self.state.cache):
result = self._compute_approximated_attention_output(
t_2_output, t_output, self.weight_callback(module), batch_size
)
output += (result,)
else:
logger.debug("FasterCache - Computing attention")
output = self.fn_ref.original_forward(*args, **kwargs)
# Note that the following condition for getting hidden_states should suffice since Diffusers blocks either return
# a single hidden_states tensor, or a tuple of (hidden_states, encoder_hidden_states) tensors. We need to handle
# both cases.
if torch.is_tensor(output):
cache_output = output
if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size:
# The output here can be both unconditional-conditional branch outputs or just conditional branch outputs.
# This is determined at the higher-level denoiser module. We only want to cache the conditional branch outputs.
cache_output = cache_output.chunk(2, dim=0)[1]
else:
# Cache all return values and perform the same operation as above
cache_output = ()
for out in output:
if not self.is_guidance_distilled and out.size(0) == self.state.batch_size:
out = out.chunk(2, dim=0)[1]
cache_output += (out,)
if self.state.cache is None:
self.state.cache = [cache_output, cache_output]
else:
self.state.cache = [self.state.cache[-1], cache_output]
self.state.iteration += 1
return output
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
self.state.reset()
return module
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None:
r"""
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
Args:
pipeline (`DiffusionPipeline`):
The diffusion pipeline to apply FasterCache to.
config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
The configuration to use for FasterCache.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> config = FasterCacheConfig(
... spatial_attention_block_skip_range=2,
... spatial_attention_timestep_skip_range=(-1, 681),
... low_frequency_weight_update_timestep_range=(99, 641),
... high_frequency_weight_update_timestep_range=(-1, 301),
... spatial_attention_block_identifiers=["transformer_blocks"],
... attention_weight_callback=lambda _: 0.3,
... tensor_format="BFCHW",
... )
>>> apply_faster_cache(pipe.transformer, config)
```
"""
logger.warning(
"FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. "
"The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at "
"https://github.com/huggingface/diffusers/issues."
)
if config.attention_weight_callback is None:
# If the user has not provided a weight callback, we default to 0.5 for all timesteps.
# In the paper, they recommend using a gradually increasing weight from 0 to 1 as the inference progresses, but
# this depends from model-to-model. It is required by the user to provide a weight callback if they want to
# use a different weight function. Defaulting to 0.5 works well in practice for most cases.
logger.warning(
"No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps."
)
config.attention_weight_callback = lambda _: 0.5
if config.low_frequency_weight_callback is None:
logger.debug(
"Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
)
def low_frequency_weight_callback(module: torch.nn.Module) -> float:
is_within_range = (
config.low_frequency_weight_update_timestep_range[0]
< config.current_timestep_callback()
< config.low_frequency_weight_update_timestep_range[1]
)
return config.alpha_low_frequency if is_within_range else 1.0
config.low_frequency_weight_callback = low_frequency_weight_callback
if config.high_frequency_weight_callback is None:
logger.debug(
"High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper."
)
def high_frequency_weight_callback(module: torch.nn.Module) -> float:
is_within_range = (
config.high_frequency_weight_update_timestep_range[0]
< config.current_timestep_callback()
< config.high_frequency_weight_update_timestep_range[1]
)
return config.alpha_high_frequency if is_within_range else 1.0
config.high_frequency_weight_callback = high_frequency_weight_callback
supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] # TODO(aryan): Support BSC for LTX Video
if config.tensor_format not in supported_tensor_formats:
raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.")
_apply_faster_cache_on_denoiser(module, config)
for name, submodule in module.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES):
continue
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
_apply_faster_cache_on_attention_class(name, submodule, config)
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None:
hook = FasterCacheDenoiserHook(
config.unconditional_batch_skip_range,
config.unconditional_batch_timestep_skip_range,
config.tensor_format,
config.is_guidance_distilled,
config._unconditional_conditional_input_kwargs_identifiers,
config.current_timestep_callback,
config.low_frequency_weight_callback,
config.high_frequency_weight_callback,
)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
and not getattr(module, "is_cross_attention", False)
)
is_temporal_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
and config.temporal_attention_block_skip_range is not None
and not module.is_cross_attention
)
block_skip_range, timestep_skip_range, block_type = None, None, None
if is_spatial_self_attention:
block_skip_range = config.spatial_attention_block_skip_range
timestep_skip_range = config.spatial_attention_timestep_skip_range
block_type = "spatial"
elif is_temporal_self_attention:
block_skip_range = config.temporal_attention_block_skip_range
timestep_skip_range = config.temporal_attention_timestep_skip_range
block_type = "temporal"
if block_skip_range is None or timestep_skip_range is None:
logger.debug(
f'Unable to apply FasterCache to the selected layer: "{name}" because it does '
f"not match any of the required criteria for spatial or temporal attention layers. Note, "
f"however, that this layer may still be valid for applying PAB. Please specify the correct "
f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` "
f"function to apply FasterCache to this layer."
)
return
logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}")
hook = FasterCacheBlockHook(
block_skip_range,
timestep_skip_range,
config.is_guidance_distilled,
config.attention_weight_callback,
config.current_timestep_callback,
)
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK)
# Reference: https://github.com/Vchitect/FasterCache/blob/fab32c15014636dc854948319c0a9a8d92c7acb4/scripts/latte/faster_cache_sample_latte.py#L127C1-L143C39
@torch.no_grad()
def _split_low_high_freq(x):
fft = torch.fft.fft2(x)
fft_shifted = torch.fft.fftshift(fft)
height, width = x.shape[-2:]
radius = min(height, width) // 5
y_grid, x_grid = torch.meshgrid(torch.arange(height), torch.arange(width))
center_x, center_y = width // 2, height // 2
mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2
low_freq_mask = mask.unsqueeze(0).unsqueeze(0).to(x.device)
high_freq_mask = ~low_freq_mask
low_freq_fft = fft_shifted * low_freq_mask
high_freq_fft = fft_shifted * high_freq_mask
return low_freq_fft, high_freq_fft
+105 -48
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from contextlib import contextmanager, nullcontext
from typing import Dict, List, Optional, Set, Tuple
import torch
@@ -56,7 +56,7 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
low_cpu_mem_usage=False,
onload_self: bool = True,
) -> None:
self.modules = modules
@@ -64,15 +64,50 @@ class ModuleGroup:
self.onload_device = onload_device
self.offload_leader = offload_leader
self.onload_leader = onload_leader
self.parameters = parameters
self.buffers = buffers
self.parameters = parameters or []
self.buffers = buffers or []
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self
self.low_cpu_mem_usage = low_cpu_mem_usage
if self.stream is not None and self.cpu_param_dict is None:
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
self.cpu_param_dict = self._init_cpu_param_dict()
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
return cpu_param_dict
for module in self.modules:
for param in module.parameters():
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
for buffer in module.buffers():
cpu_param_dict[buffer] = (
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
)
for param in self.parameters:
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
for buffer in self.buffers:
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
return cpu_param_dict
@contextmanager
def _pinned_memory_tensors(self):
pinned_dict = {}
try:
for param, tensor in self.cpu_param_dict.items():
if not tensor.is_pinned():
pinned_dict[param] = tensor.pin_memory()
else:
pinned_dict[param] = tensor
yield pinned_dict
finally:
pinned_dict = None
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
@@ -82,12 +117,30 @@ class ModuleGroup:
self.stream.synchronize()
with context:
for group_module in self.modules:
group_module.to(self.onload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
if self.stream is not None:
with self._pinned_memory_tensors() as pinned_memory:
for group_module in self.modules:
for param in group_module.parameters():
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
for buffer in group_module.buffers():
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
for param in self.parameters:
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
for buffer in self.buffers:
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
else:
for group_module in self.modules:
for param in group_module.parameters():
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
for buffer in group_module.buffers():
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
@@ -98,15 +151,18 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
class GroupOffloadingHook(ModelHook):
@@ -172,6 +228,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
self._layer_execution_tracker_module_names = set()
def initialize_hook(self, module):
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
# To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
# of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
# layers are executed during the forward pass.
@@ -183,14 +246,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None:
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
# For the first forward pass, we have to load in a blocking manner
group_offloading_hook.group.non_blocking = False
layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
self._layer_execution_tracker_module_names.add(name)
@@ -220,6 +277,7 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the layer execution tracker hooks from the submodules
base_module_registry = module._diffusers_hook
registries = [submodule._diffusers_hook for _, submodule in self.execution_order]
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
for i in range(num_executed):
registries[i].remove_hook(_LAYER_EXECUTION_TRACKER, recurse=False)
@@ -227,8 +285,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=False)
# Apply lazy prefetching by setting required attributes
group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
# LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
# We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
# see the benefits of prefetching.
for hook in group_offloading_hooks:
hook.group.non_blocking = True
# Set required attributes for prefetching
if num_executed > 0:
base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
@@ -268,6 +331,7 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -314,6 +378,10 @@ def apply_group_offloading(
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
Example:
```python
@@ -349,10 +417,12 @@ def apply_group_offloading(
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
_apply_group_offloading_leaf_level(
module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage
)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -364,6 +434,7 @@ def _apply_group_offloading_block_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -384,13 +455,6 @@ def _apply_group_offloading_block_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
unmatched_modules = []
@@ -411,7 +475,7 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=stream is None,
)
matched_module_groups.append(group)
@@ -448,7 +512,6 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -461,6 +524,7 @@ def _apply_group_offloading_leaf_level(
onload_device: torch.device,
non_blocking: bool,
stream: Optional[torch.cuda.Stream] = None,
low_cpu_mem_usage: bool = False,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -483,13 +547,6 @@ def _apply_group_offloading_leaf_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
@@ -503,7 +560,7 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
@@ -548,7 +605,7 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -567,7 +624,7 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
cpu_param_dict=None,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -26,8 +26,8 @@ from .hooks import HookRegistry, ModelHook
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@@ -87,7 +87,7 @@ class PyramidAttentionBroadcastConfig:
def __repr__(self) -> str:
return (
f"PyramidAttentionBroadcastConfig("
f"PyramidAttentionBroadcastConfig(\n"
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n"
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n"
f" cross_attention_block_skip_range={self.cross_attention_block_skip_range},\n"
@@ -175,10 +175,7 @@ class PyramidAttentionBroadcastHook(ModelHook):
return module
def apply_pyramid_attention_broadcast(
module: torch.nn.Module,
config: PyramidAttentionBroadcastConfig,
):
def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAttentionBroadcastConfig):
r"""
Apply [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) to a given pipeline.
@@ -311,4 +308,4 @@ def _apply_pyramid_attention_broadcast_hook(
"""
registry = HookRegistry.check_if_exists_or_initialize(module)
hook = PyramidAttentionBroadcastHook(timestep_skip_range, block_skip_range, current_timestep_callback)
registry.register_hook(hook, "pyramid_attention_broadcast")
registry.register_hook(hook, _PYRAMID_ATTENTION_BROADCAST_HOOK)
+25
View File
@@ -316,6 +316,7 @@ def _load_lora_into_text_encoder(
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -341,6 +342,10 @@ def _load_lora_into_text_encoder(
# their prefixes.
prefix = text_encoder_name if prefix is None else prefix
# Safe prefix to check with.
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
@@ -908,3 +913,23 @@ class LoraBaseMixin:
# property function that returns the lora scale which can be set at run time by the pipeline.
# if _lora_scale has not been set, return 1
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def enable_lora_hotswap(self, **kwargs) -> None:
"""Enables the possibility to hotswap LoRA adapters.
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ.
Args:
target_rank (`int`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
for key, component in self.components.items():
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
component.enable_lora_hotswap(**kwargs)
File diff suppressed because it is too large Load Diff
+130 -5
View File
@@ -16,7 +16,7 @@ import inspect
import os
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
@@ -128,6 +128,8 @@ class PeftAdapterMixin:
"""
_hf_peft_config_loaded = False
# kwargs for prepare_model_for_compiled_hotswap, if required
_prepare_lora_hotswap_kwargs: Optional[dict] = None
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
@@ -145,7 +147,9 @@ class PeftAdapterMixin:
"""
return _func_optionally_disable_offloading(_pipeline=_pipeline)
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
def load_lora_adapter(
self, pretrained_model_name_or_path_or_dict, prefix="transformer", hotswap: bool = False, **kwargs
):
r"""
Loads a LoRA adapter into the underlying model.
@@ -189,6 +193,29 @@ class PeftAdapterMixin:
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap : (`bool`, *optional*)
Defaults to `False`. Whether to substitute an existing (LoRA) adapter with the newly loaded adapter
in-place. This means that, instead of loading an additional adapter, this will take the existing
adapter weights and replace them with the weights of the new adapter. This can be faster and more
memory efficient. However, the main advantage of hotswapping is that when the model is compiled with
torch.compile, loading the new adapter does not require recompilation of the model. When using
hotswapping, the passed `adapter_name` should be the name of an already loaded adapter.
If the new adapter and the old adapter have different ranks and/or LoRA alphas (i.e. scaling), you need
to call an additional method before loading the adapter:
```py
pipeline = ... # load diffusers pipeline
max_rank = ... # the highest rank among all LoRAs that you want to load
# call *before* compiling and loading the LoRA adapter
pipeline.enable_lora_hotswap(target_rank=max_rank)
pipeline.load_lora_weights(file_name)
# optionally compile the model now
```
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
@@ -239,10 +266,15 @@ class PeftAdapterMixin:
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}):
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
raise ValueError(
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
)
elif adapter_name not in getattr(self, "peft_config", {}) and hotswap:
raise ValueError(
f"Trying to hotswap LoRA adapter '{adapter_name}' but there is no existing adapter by that name. "
"Please choose an existing adapter name or set `hotswap=False` to prevent hotswapping."
)
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
@@ -302,11 +334,71 @@ class PeftAdapterMixin:
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
if hotswap or (self._prepare_lora_hotswap_kwargs is not None):
if is_peft_version(">", "0.14.0"):
from peft.utils.hotswap import (
check_hotswap_configs_compatible,
hotswap_adapter_from_state_dict,
prepare_model_for_compiled_hotswap,
)
else:
msg = (
"Hotswapping requires PEFT > v0.14. Please upgrade PEFT to a higher version or install it "
"from source."
)
raise ImportError(msg)
if hotswap:
def map_state_dict_for_hotswap(sd):
# For hotswapping, we need the adapter name to be present in the state dict keys
new_sd = {}
for k, v in sd.items():
if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
k = k[: -len(".weight")] + f".{adapter_name}.weight"
elif k.endswith("lora_B.bias"): # lora_bias=True option
k = k[: -len(".bias")] + f".{adapter_name}.bias"
new_sd[k] = v
return new_sd
# To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
# we should also delete the `peft_config` associated to the `adapter_name`.
try:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if hotswap:
state_dict = map_state_dict_for_hotswap(state_dict)
check_hotswap_configs_compatible(self.peft_config[adapter_name], lora_config)
try:
hotswap_adapter_from_state_dict(
model=self,
state_dict=state_dict,
adapter_name=adapter_name,
config=lora_config,
)
except Exception as e:
logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
raise
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
# it to None
incompatible_keys = None
else:
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
if self._prepare_lora_hotswap_kwargs is not None:
# For hotswapping of compiled models or adapters with different ranks.
# If the user called enable_lora_hotswap, we need to ensure it is called:
# - after the first adapter was loaded
# - before the model is compiled and the 2nd adapter is being hotswapped in
# Therefore, it needs to be called here
prepare_model_for_compiled_hotswap(
self, config=lora_config, **self._prepare_lora_hotswap_kwargs
)
# We only want to call prepare_model_for_compiled_hotswap once
self._prepare_lora_hotswap_kwargs = None
# Set peft config loaded flag to True if module has been successfully injected and incompatible keys retrieved
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
except Exception as e:
# In case `inject_adapter_in_model()` was unsuccessful even before injecting the `peft_config`.
if hasattr(self, "peft_config"):
@@ -766,3 +858,36 @@ class PeftAdapterMixin:
# Pop also the corresponding adapter from the config
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
def enable_lora_hotswap(
self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
) -> None:
"""Enables the possibility to hotswap LoRA adapters.
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
the loaded adapters differ.
Args:
target_rank (`int`, *optional*, defaults to `128`):
The highest rank among all the adapters that will be loaded.
check_compiled (`str`, *optional*, defaults to `"error"`):
How to handle the case when the model is already compiled, which should generally be avoided. The
options are:
- "error" (default): raise an error
- "warn": issue a warning
- "ignore": do nothing
"""
if getattr(self, "peft_config", {}):
if check_compiled == "error":
raise RuntimeError("Call `enable_lora_hotswap` before loading the first adapter.")
elif check_compiled == "warn":
logger.warning(
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
elif check_compiled != "ignore":
raise ValueError(
f"check_compiles should be one of 'error', 'warn', or 'ignore', got '{check_compiled}' instead."
)
self._prepare_lora_hotswap_kwargs = {"target_rank": target_rank, "check_compiled": check_compiled}
+2 -2
View File
@@ -360,12 +360,12 @@ class FromSingleFileMixin:
cache_dir = kwargs.pop("cache_dir", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
torch_dtype = kwargs.pop("torch_dtype", None)
disable_mmap = kwargs.pop("disable_mmap", False)
is_legacy_loading = False
if not isinstance(torch_dtype, torch.dtype):
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+3 -2
View File
@@ -255,12 +255,12 @@ class FromOriginalModelMixin:
subfolder = kwargs.pop("subfolder", None)
revision = kwargs.pop("revision", None)
config_revision = kwargs.pop("config_revision", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
torch_dtype = kwargs.pop("torch_dtype", None)
quantization_config = kwargs.pop("quantization_config", None)
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)
if not isinstance(torch_dtype, torch.dtype):
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -282,6 +282,7 @@ class FromOriginalModelMixin:
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
hf_quantizer.validate_environment()
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
else:
hf_quantizer = None
+2 -2
View File
@@ -44,6 +44,7 @@ from ..utils import (
is_transformers_available,
logging,
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
@@ -443,7 +444,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
"Please provide a valid local file path."
)
original_config_file = BytesIO(requests.get(original_config_file).content)
original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
@@ -2406,7 +2407,6 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"timestep_scale_multiplier": remove_keys_,
}
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+2 -2
View File
@@ -449,9 +449,9 @@ class TextualInversionLoaderMixin:
# 7.5 Offload the model again
if is_model_cpu_offload:
self.enable_model_cpu_offload()
self.enable_model_cpu_offload(device=device)
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
self.enable_sequential_cpu_offload(device=device)
# / Unsafe Code >
@@ -6020,6 +6020,11 @@ class SanaLinearAttnProcessor2_0:
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
@@ -105,6 +105,7 @@ class CogVideoXCausalConv3d(nn.Module):
self.width_pad = width_pad
self.time_pad = time_pad
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
self.const_padding_conv3d = (0, self.width_pad, self.height_pad)
self.temporal_dim = 2
self.time_kernel_size = time_kernel_size
@@ -117,6 +118,8 @@ class CogVideoXCausalConv3d(nn.Module):
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=0 if self.pad_mode == "replicate" else self.const_padding_conv3d,
padding_mode="zeros",
)
def fake_context_parallel_forward(
@@ -137,9 +140,7 @@ class CogVideoXCausalConv3d(nn.Module):
if self.pad_mode == "replicate":
conv_cache = None
else:
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
output = self.conv(inputs)
return output, conv_cache
@@ -196,6 +196,55 @@ class LTXVideoResnetBlock3d(nn.Module):
return hidden_states
class LTXVideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
stride: Union[int, Tuple[int, int, int]] = 1,
is_causal: bool = True,
padding_mode: str = "zeros",
) -> None:
super().__init__()
self.stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
self.group_size = (in_channels * stride[0] * stride[1] * stride[2]) // out_channels
out_channels = out_channels // (self.stride[0] * self.stride[1] * self.stride[2])
self.conv = LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = torch.cat([hidden_states[:, :, : self.stride[0] - 1], hidden_states], dim=2)
residual = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
residual = residual.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
residual = residual.unflatten(1, (-1, self.group_size))
residual = residual.mean(dim=2)
hidden_states = self.conv(hidden_states)
hidden_states = (
hidden_states.unflatten(4, (-1, self.stride[2]))
.unflatten(3, (-1, self.stride[1]))
.unflatten(2, (-1, self.stride[0]))
)
hidden_states = hidden_states.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(1, 4)
hidden_states = hidden_states + residual
return hidden_states
class LTXVideoUpsampler3d(nn.Module):
def __init__(
self,
@@ -204,6 +253,7 @@ class LTXVideoUpsampler3d(nn.Module):
is_causal: bool = True,
residual: bool = False,
upscale_factor: int = 1,
padding_mode: str = "zeros",
) -> None:
super().__init__()
@@ -219,6 +269,7 @@ class LTXVideoUpsampler3d(nn.Module):
kernel_size=3,
stride=1,
is_causal=is_causal,
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -352,6 +403,118 @@ class LTXVideoDownBlock3D(nn.Module):
return hidden_states
class LTXVideo095DownBlock3D(nn.Module):
r"""
Down block used in the LTXVideo model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
num_layers (`int`, defaults to `1`):
Number of resnet layers.
dropout (`float`, defaults to `0.0`):
Dropout rate.
resnet_eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
resnet_act_fn (`str`, defaults to `"swish"`):
Activation function to use.
spatio_temporal_scale (`bool`, defaults to `True`):
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
Whether or not to downsample across temporal dimension.
is_causal (`bool`, defaults to `True`):
Whether this layer behaves causally (future frames depend only on past frames) or not.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
is_causal: bool = True,
downsample_type: str = "conv",
):
super().__init__()
out_channels = out_channels or in_channels
resnets = []
for _ in range(num_layers):
resnets.append(
LTXVideoResnetBlock3d(
in_channels=in_channels,
out_channels=in_channels,
dropout=dropout,
eps=resnet_eps,
non_linearity=resnet_act_fn,
is_causal=is_causal,
)
)
self.resnets = nn.ModuleList(resnets)
self.downsamplers = None
if spatio_temporal_scale:
self.downsamplers = nn.ModuleList()
if downsample_type == "conv":
self.downsamplers.append(
LTXVideoCausalConv3d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=(2, 2, 2),
is_causal=is_causal,
)
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(1, 2, 2), is_causal=is_causal
)
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 1, 1), is_causal=is_causal
)
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
in_channels=in_channels, out_channels=out_channels, stride=(2, 2, 2), is_causal=is_causal
)
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
r"""Forward method of the `LTXDownBlock3D` class."""
for i, resnet in enumerate(self.resnets):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb, generator)
else:
hidden_states = resnet(hidden_states, temb, generator)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
return hidden_states
# Adapted from diffusers.models.autoencoders.autoencoder_kl_cogvideox.CogVideoMidBlock3d
class LTXVideoMidBlock3d(nn.Module):
r"""
@@ -593,8 +756,15 @@ class LTXVideoEncoder3d(nn.Module):
in_channels: int = 3,
out_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -617,20 +787,37 @@ class LTXVideoEncoder3d(nn.Module):
)
# down blocks
num_block_out_channels = len(block_out_channels)
is_ltx_095 = down_block_types[-1] == "LTXVideo095DownBlock3D"
num_block_out_channels = len(block_out_channels) - (1 if is_ltx_095 else 0)
self.down_blocks = nn.ModuleList([])
for i in range(num_block_out_channels):
input_channel = output_channel
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
if not is_ltx_095:
output_channel = block_out_channels[i + 1] if i + 1 < num_block_out_channels else block_out_channels[i]
else:
output_channel = block_out_channels[i + 1]
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
if down_block_types[i] == "LTXVideoDownBlock3D":
down_block = LTXVideoDownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
)
elif down_block_types[i] == "LTXVideo095DownBlock3D":
down_block = LTXVideo095DownBlock3D(
in_channels=input_channel,
out_channels=output_channel,
num_layers=layers_per_block[i],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
is_causal=is_causal,
downsample_type=downsample_type[i],
)
else:
raise ValueError(f"Unknown down block type: {down_block_types[i]}")
self.down_blocks.append(down_block)
@@ -794,7 +981,9 @@ class LTXVideoDecoder3d(nn.Module):
# timestep embedding
self.time_embedder = None
self.scale_shift_table = None
self.timestep_scale_multiplier = None
if timestep_conditioning:
self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0, dtype=torch.float32))
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(output_channel * 2, 0)
self.scale_shift_table = nn.Parameter(torch.randn(2, output_channel) / output_channel**0.5)
@@ -803,6 +992,9 @@ class LTXVideoDecoder3d(nn.Module):
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
hidden_states = self.conv_in(hidden_states)
if self.timestep_scale_multiplier is not None:
temb = temb * self.timestep_scale_multiplier
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states, temb)
@@ -891,12 +1083,19 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
out_channels: int = 3,
latent_channels: int = 128,
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
down_block_types: Tuple[str, ...] = (
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
"LTXVideoDownBlock3D",
),
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
decoder_layers_per_block: Tuple[int, ...] = (4, 3, 3, 3, 4),
spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_spatio_temporal_scaling: Tuple[bool, ...] = (True, True, True, False),
decoder_inject_noise: Tuple[bool, ...] = (False, False, False, False, False),
downsample_type: Tuple[str, ...] = ("conv", "conv", "conv", "conv"),
upsample_residual: Tuple[bool, ...] = (False, False, False, False),
upsample_factor: Tuple[int, ...] = (1, 1, 1, 1),
timestep_conditioning: bool = False,
@@ -906,6 +1105,8 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
scaling_factor: float = 1.0,
encoder_causal: bool = True,
decoder_causal: bool = False,
spatial_compression_ratio: int = None,
temporal_compression_ratio: int = None,
) -> None:
super().__init__()
@@ -913,8 +1114,10 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
in_channels=in_channels,
out_channels=latent_channels,
block_out_channels=block_out_channels,
down_block_types=down_block_types,
spatio_temporal_scaling=spatio_temporal_scaling,
layers_per_block=layers_per_block,
downsample_type=downsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,
@@ -941,8 +1144,16 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
self.spatial_compression_ratio = patch_size * 2 ** sum(spatio_temporal_scaling)
self.temporal_compression_ratio = patch_size_t * 2 ** sum(spatio_temporal_scaling)
self.spatial_compression_ratio = (
patch_size * 2 ** sum(spatio_temporal_scaling)
if spatial_compression_ratio is None
else spatial_compression_ratio
)
self.temporal_compression_ratio = (
patch_size_t * 2 ** sum(spatio_temporal_scaling)
if temporal_compression_ratio is None
else temporal_compression_ratio
)
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
@@ -210,7 +210,7 @@ class MochiDownBlock3D(nn.Module):
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
conv_cache=conv_cache.get(conv_cache_key),
conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -306,7 +306,7 @@ class MochiMidBlock3D(nn.Module):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
resnet, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -382,7 +382,7 @@ class MochiUpBlock3D(nn.Module):
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
resnet,
hidden_states,
conv_cache=conv_cache.get(conv_cache_key),
conv_cache.get(conv_cache_key),
)
else:
hidden_states, new_conv_cache[conv_cache_key] = resnet(
@@ -497,6 +497,8 @@ class MochiEncoder3D(nn.Module):
self.norm_out = MochiChunkedGroupNorm3D(block_out_channels[-1])
self.proj_out = nn.Linear(block_out_channels[-1], 2 * out_channels, bias=False)
self.gradient_checkpointing = False
def forward(
self, hidden_states: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
) -> torch.Tensor:
@@ -513,13 +515,13 @@ class MochiEncoder3D(nn.Module):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
self.block_in, hidden_states, conv_cache.get("block_in")
)
for i, down_block in enumerate(self.down_blocks):
conv_cache_key = f"down_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
down_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
down_block, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
@@ -623,13 +625,13 @@ class MochiDecoder3D(nn.Module):
# 1. Mid
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, new_conv_cache["block_in"] = self._gradient_checkpointing_func(
self.block_in, hidden_states, conv_cache=conv_cache.get("block_in")
self.block_in, hidden_states, conv_cache.get("block_in")
)
for i, up_block in enumerate(self.up_blocks):
conv_cache_key = f"up_block_{i}"
hidden_states, new_conv_cache[conv_cache_key] = self._gradient_checkpointing_func(
up_block, hidden_states, conv_cache=conv_cache.get(conv_cache_key)
up_block, hidden_states, conv_cache.get(conv_cache_key)
)
else:
hidden_states, new_conv_cache["block_in"] = self.block_in(
+22 -3
View File
@@ -24,6 +24,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
"""
_cache_config = None
@@ -59,17 +60,31 @@ class CacheMixin:
```
"""
from ..hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from ..hooks import (
FasterCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)
if self.is_cache_enabled:
raise ValueError(
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
@@ -77,7 +92,11 @@ class CacheMixin:
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook("pyramid_attention_broadcast", recurse=True)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, FasterCacheConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
+1 -1
View File
@@ -336,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
" `from_numpy` is no longer required."
" Pass `output_type='pt' to use the new version now."
)
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
+14 -4
View File
@@ -37,7 +37,6 @@ from torch import Tensor, nn
from typing_extensions import Self
from .. import __version__
from ..hooks import apply_group_offloading, apply_layerwise_casting
from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
@@ -504,6 +503,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
non_blocking (`bool`, *optional*, defaults to `False`):
If `True`, the weight casting operations are non-blocking.
"""
from ..hooks import apply_layerwise_casting
user_provided_patterns = True
if skip_modules_pattern is None:
@@ -546,6 +546,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
low_cpu_mem_usage=False,
) -> None:
r"""
Activates group offloading for the current model.
@@ -569,6 +570,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
... )
```
"""
from ..hooks import apply_group_offloading
if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
msg = (
"Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
@@ -584,7 +587,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"open an issue at https://github.com/huggingface/diffusers/issues."
)
apply_group_offloading(
self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
self,
onload_device,
offload_device,
offload_type,
num_blocks_per_group,
non_blocking,
use_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
def save_pretrained(
@@ -870,7 +880,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
torch_dtype = kwargs.pop("torch_dtype", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
max_memory = kwargs.pop("max_memory", None)
@@ -883,7 +893,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
if not isinstance(torch_dtype, torch.dtype):
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
-10
View File
@@ -550,16 +550,6 @@ class RMSNorm(nn.Module):
hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
if self.bias is not None:
hidden_states = hidden_states + self.bias
elif is_torch_version(">=", "2.4"):
if self.weight is not None:
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = nn.functional.rms_norm(
hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
)
if self.bias is not None:
hidden_states = hidden_states + self.bias
else:
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+1 -1
View File
@@ -366,7 +366,7 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
input_tensor = self.conv_shortcut(input_tensor.contiguous())
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@@ -273,7 +273,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
if i == 0 and num_frame > 1:
hidden_states = hidden_states + self.temp_pos_embed
hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
@@ -15,6 +15,7 @@
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
@@ -23,10 +24,9 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
from ..attention_processor import (
Attention,
AttentionProcessor,
AttnProcessor2_0,
SanaLinearAttnProcessor2_0,
)
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -96,6 +96,95 @@ class SanaModulatedNorm(nn.Module):
return hidden_states
class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
def __init__(self, embedding_dim):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
guidance_proj = self.guidance_condition_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
conditioning = timesteps_emb + guidance_emb
return self.linear(self.silu(conditioning)), conditioning
class SanaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class SanaTransformerBlock(nn.Module):
r"""
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -115,6 +204,7 @@ class SanaTransformerBlock(nn.Module):
norm_eps: float = 1e-6,
attention_out_bias: bool = True,
mlp_ratio: float = 2.5,
qk_norm: Optional[str] = None,
) -> None:
super().__init__()
@@ -124,6 +214,8 @@ class SanaTransformerBlock(nn.Module):
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
kv_heads=num_attention_heads if qk_norm is not None else None,
qk_norm=qk_norm,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
@@ -135,13 +227,15 @@ class SanaTransformerBlock(nn.Module):
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
qk_norm=qk_norm,
kv_heads=num_cross_attention_heads if qk_norm is not None else None,
cross_attention_dim=cross_attention_dim,
heads=num_cross_attention_heads,
dim_head=cross_attention_head_dim,
dropout=dropout,
bias=True,
out_bias=attention_out_bias,
processor=AttnProcessor2_0(),
processor=SanaAttnProcessor2_0(),
)
# 3. Feed-forward
@@ -232,6 +326,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
Whether to use elementwise affinity in the normalization layer.
norm_eps (`float`, defaults to `1e-6`):
The epsilon value for the normalization layer.
qk_norm (`str`, *optional*, defaults to `None`):
The normalization to use for the query and key.
timestep_scale (`float`, defaults to `1.0`):
The scale to use for the timesteps.
"""
_supports_gradient_checkpointing = True
@@ -258,6 +356,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
guidance_embeds: bool = False,
guidance_embeds_scale: float = 0.1,
qk_norm: Optional[str] = None,
timestep_scale: float = 1.0,
) -> None:
super().__init__()
@@ -276,7 +378,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
)
# 2. Additional condition embeddings
self.time_embed = AdaLayerNormSingle(inner_dim)
if guidance_embeds:
self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
else:
self.time_embed = AdaLayerNormSingle(inner_dim)
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
@@ -296,6 +401,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
qk_norm=qk_norm,
)
for _ in range(num_layers)
]
@@ -372,7 +478,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
timestep: torch.Tensor,
guidance: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -423,9 +530,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
hidden_states = self.patch_embed(hidden_states)
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
if guidance is not None:
timestep, embedded_timestep = self.time_embed(
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
)
else:
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
@@ -27,13 +27,15 @@ from ..attention import FeedForward
from ..attention_processor import Attention, AttentionProcessor
from ..cache_utils import CacheMixin
from ..embeddings import (
CombinedTimestepGuidanceTextProjEmbeddings,
CombinedTimestepTextProjEmbeddings,
PixArtAlphaTextProjection,
TimestepEmbedding,
Timesteps,
get_1d_rotary_pos_embed,
)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle, FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -173,6 +175,141 @@ class HunyuanVideoAdaNorm(nn.Module):
return gate_msa, gate_mlp
class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
elif norm_type == "fp32_layer_norm":
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
hidden_states: torch.Tensor,
emb: torch.Tensor,
token_replace_emb: torch.Tensor,
first_frame_num_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
token_replace_emb = self.linear(self.silu(token_replace_emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
6, dim=1
)
norm_hidden_states = self.norm(hidden_states)
hidden_states_zero = (
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
)
hidden_states_orig = (
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
return (
hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
tr_gate_msa,
tr_shift_mlp,
tr_scale_mlp,
tr_gate_mlp,
)
class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
else:
raise ValueError(
f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
)
def forward(
self,
hidden_states: torch.Tensor,
emb: torch.Tensor,
token_replace_emb: torch.Tensor,
first_frame_num_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
token_replace_emb = self.linear(self.silu(token_replace_emb))
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
norm_hidden_states = self.norm(hidden_states)
hidden_states_zero = (
norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
)
hidden_states_orig = (
norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
return hidden_states, gate_msa, tr_gate_msa
class HunyuanVideoConditionEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
pooled_projection_dim: int,
guidance_embeds: bool,
image_condition_type: Optional[str] = None,
):
super().__init__()
self.image_condition_type = image_condition_type
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
self.guidance_embedder = None
if guidance_embeds:
self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
def forward(
self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
pooled_projections = self.text_embedder(pooled_projection)
conditioning = timesteps_emb + pooled_projections
token_replace_emb = None
if self.image_condition_type == "token_replace":
token_replace_timestep = torch.zeros_like(timestep)
token_replace_proj = self.time_proj(token_replace_timestep)
token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
token_replace_emb = token_replace_emb + pooled_projections
if self.guidance_embedder is not None:
guidance_proj = self.time_proj(guidance)
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
conditioning = conditioning + guidance_emb
return conditioning, token_replace_emb
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
def __init__(
self,
@@ -390,6 +527,8 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -468,6 +607,8 @@ class HunyuanVideoTransformerBlock(nn.Module):
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -503,6 +644,181 @@ class HunyuanVideoTransformerBlock(nn.Module):
return hidden_states, encoder_hidden_states
class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float = 4.0,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
mlp_dim = int(hidden_size * mlp_ratio)
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
pre_only=True,
)
self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
residual = hidden_states
# 1. Input normalization
norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
norm_hidden_states, norm_encoder_hidden_states = (
norm_hidden_states[:, :-text_seq_length, :],
norm_hidden_states[:, -text_seq_length:, :],
)
# 2. Attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
# 3. Modulation and residual connection
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
proj_output = self.proj_out(hidden_states)
hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
hidden_states = hidden_states + residual
hidden_states, encoder_hidden_states = (
hidden_states[:, :-text_seq_length, :],
hidden_states[:, -text_seq_length:, :],
)
return hidden_states, encoder_hidden_states
class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
def __init__(
self,
num_attention_heads: int,
attention_head_dim: int,
mlp_ratio: float,
qk_norm: str = "rms_norm",
) -> None:
super().__init__()
hidden_size = num_attention_heads * attention_head_dim
self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
self.attn = Attention(
query_dim=hidden_size,
cross_attention_dim=None,
added_kv_proj_dim=hidden_size,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=hidden_size,
context_pre_only=False,
bias=True,
processor=HunyuanVideoAttnProcessor2_0(),
qk_norm=qk_norm,
eps=1e-6,
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Input normalization
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
tr_gate_msa,
tr_shift_mlp,
tr_scale_mlp,
tr_gate_mlp,
) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
# 2. Joint attention
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=freqs_cis,
)
# 3. Modulation and residual connection
hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
norm_hidden_states = self.norm2(hidden_states)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
return hidden_states, encoder_hidden_states
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
@@ -540,6 +856,10 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
The value of theta to use in the RoPE layer.
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
The dimensions of the axes to use in the RoPE layer.
image_condition_type (`str`, *optional*, defaults to `None`):
The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
tokens in the latent stream and apply conditioning.
"""
_supports_gradient_checkpointing = True
@@ -570,9 +890,16 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
pooled_projection_dim: int = 768,
rope_theta: float = 256.0,
rope_axes_dim: Tuple[int] = (16, 56, 56),
image_condition_type: Optional[str] = None,
) -> None:
super().__init__()
supported_image_condition_types = ["latent_concat", "token_replace"]
if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
raise ValueError(
f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
)
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
@@ -582,33 +909,52 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
)
if guidance_embeds:
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
else:
self.time_text_embed = CombinedTimestepTextProjEmbeddings(inner_dim, pooled_projection_dim)
self.time_text_embed = HunyuanVideoConditionEmbedding(
inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
)
# 2. RoPE
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
# 3. Dual stream transformer blocks
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
if image_condition_type == "token_replace":
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTokenReplaceTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
else:
self.transformer_blocks = nn.ModuleList(
[
HunyuanVideoTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_layers)
]
)
# 4. Single stream transformer blocks
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
if image_condition_type == "token_replace":
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoTokenReplaceSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
else:
self.single_transformer_blocks = nn.ModuleList(
[
HunyuanVideoSingleTransformerBlock(
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
)
for _ in range(num_single_layers)
]
)
# 5. Output projection
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
@@ -707,15 +1053,13 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
first_frame_num_tokens = 1 * post_patch_height * post_patch_width
# 1. RoPE
image_rotary_emb = self.rope(hidden_states)
# 2. Conditional embeddings
if self.config.guidance_embeds:
temb = self.time_text_embed(timestep, guidance, pooled_projections)
else:
temb = self.time_text_embed(timestep, pooled_projections)
temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
@@ -746,6 +1090,8 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
for block in self.single_transformer_blocks:
@@ -756,17 +1102,31 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
hidden_states,
encoder_hidden_states,
temb,
attention_mask,
image_rotary_emb,
token_replace_emb,
first_frame_num_tokens,
)
# 5. Output projection
@@ -14,7 +14,7 @@
# limitations under the License.
import math
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -26,6 +26,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -113,20 +114,19 @@ class LTXVideoRotaryPosEmbed(nn.Module):
self.patch_size_t = patch_size_t
self.theta = theta
def forward(
def _prepare_video_coords(
self,
hidden_states: torch.Tensor,
batch_size: int,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
rope_interpolation_scale: Tuple[torch.Tensor, float, float],
device: torch.device,
) -> torch.Tensor:
# Always compute rope in fp32
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
grid_h = torch.arange(height, dtype=torch.float32, device=device)
grid_w = torch.arange(width, dtype=torch.float32, device=device)
grid_f = torch.arange(num_frames, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
@@ -138,6 +138,38 @@ class LTXVideoRotaryPosEmbed(nn.Module):
grid = grid.flatten(2, 4).transpose(1, 2)
return grid
def forward(
self,
hidden_states: torch.Tensor,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
video_coords: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.size(0)
if video_coords is None:
grid = self._prepare_video_coords(
batch_size,
num_frames,
height,
width,
rope_interpolation_scale=rope_interpolation_scale,
device=hidden_states.device,
)
else:
grid = torch.stack(
[
video_coords[:, 0] / self.base_num_frames,
video_coords[:, 1] / self.base_height,
video_coords[:, 2] / self.base_width,
],
dim=-1,
)
start = 1.0
end = self.theta
freqs = self.theta ** torch.linspace(
@@ -267,7 +299,7 @@ class LTXVideoTransformerBlock(nn.Module):
@maybe_allow_in_graph
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
r"""
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -367,10 +399,11 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
num_frames: int,
height: int,
width: int,
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
num_frames: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
video_coords: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
@@ -389,7 +422,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
@@ -24,6 +24,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -288,7 +289,7 @@ class WanTransformerBlock(nn.Module):
return hidden_states
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
r"""
A Transformer model for video-like data used in the Wan model.
+6 -6
View File
@@ -264,7 +264,7 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
@@ -280,7 +280,7 @@ else:
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
_import_structure["sana"] = ["SanaPipeline"]
_import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_audio"] = [
@@ -356,7 +356,7 @@ else:
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline"]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -618,7 +618,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
@@ -651,7 +651,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .sana import SanaPipeline
from .sana import SanaPipeline, SanaSprintPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
@@ -709,7 +709,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
UniDiffuserPipeline,
UniDiffuserTextDecoder,
)
from .wan import WanImageToVideoPipeline, WanPipeline
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
@@ -213,9 +213,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
device=text_input_ids.device,
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = self.text_encoder(
text_input_ids.to(self.text_encoder.device), output_hidden_states=True
).hidden_states[-2]
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds
@@ -68,7 +68,7 @@ def calculate_shift(
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
@@ -100,10 +100,19 @@ def retrieve_timesteps(
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps and not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif timesteps is not None and sigmas is None:
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -112,9 +121,8 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
elif timesteps is None and sigmas is not None:
if not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
@@ -208,9 +216,7 @@ class CogView4ControlPipeline(DiffusionPipeline):
device=text_input_ids.device,
)
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
prompt_embeds = self.text_encoder(
text_input_ids.to(self.text_encoder.device), output_hidden_states=True
).hidden_states[-2]
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=True).hidden_states[-2]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds
@@ -515,8 +521,8 @@ class CogView4ControlPipeline(DiffusionPipeline):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain
tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -532,7 +538,6 @@ class CogView4ControlPipeline(DiffusionPipeline):
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `224`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
Examples:
Returns:
@@ -184,7 +184,14 @@ class StableDiffusionControlNetInpaintPipeline(
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "control_image"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"control_image",
"mask",
"masked_image_latents",
]
def __init__(
self,
@@ -533,7 +533,6 @@ class FluxControlImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSin
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
def prepare_latents(
self,
image,
@@ -63,6 +63,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import FluxControlNetPipeline
>>> from diffusers import FluxControlNetModel
>>> base_model = "black-forest-labs/FLUX.1-dev"
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
>>> pipe = FluxControlNetPipeline.from_pretrained(
@@ -533,7 +533,6 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
def prepare_latents(
self,
image,
@@ -561,7 +561,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
def prepare_latents(
self,
image,
@@ -614,7 +613,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, noise, image_latents, latent_image_ids
# Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
def prepare_mask_latents(
self,
mask,
@@ -224,11 +224,13 @@ class FluxFillPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=latent_channels,
vae_latent_channels=self.latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
@@ -493,10 +495,38 @@ class FluxFillPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
return image_latents
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
def check_inputs(
self,
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=None,
@@ -507,6 +537,9 @@ class FluxFillPipeline(
mask_image=None,
masked_image_latents=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -624,9 +657,11 @@ class FluxFillPipeline(
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
@@ -636,28 +671,41 @@ class FluxFillPipeline(
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids
@property
@@ -687,6 +735,7 @@ class FluxFillPipeline(
masked_image_latents: Optional[torch.FloatTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
@@ -731,6 +780,12 @@ class FluxFillPipeline(
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -794,6 +849,7 @@ class FluxFillPipeline(
self.check_inputs(
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=prompt_embeds,
@@ -809,6 +865,9 @@ class FluxFillPipeline(
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -838,47 +897,9 @@ class FluxFillPipeline(
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare mask and masked image latents
if masked_image_latents is not None:
masked_image_latents = masked_image_latents.to(latents.device)
else:
image = self.image_processor.preprocess(image, height=height, width=width)
mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = image * (1 - mask_image)
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
height, width = image.shape[-2:]
mask, masked_image_latents = self.prepare_mask_latents(
mask_image,
masked_image,
batch_size,
num_channels_latents,
num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
)
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
# 6. Prepare timesteps
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
@@ -893,6 +914,54 @@ class FluxFillPipeline(
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare mask and masked image latents
if masked_image_latents is not None:
masked_image_latents = masked_image_latents.to(latents.device)
else:
mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = init_image * (1 - mask_image)
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
height, width = init_image.shape[-2:]
mask, masked_image_latents = self.prepare_mask_latents(
mask_image,
masked_image,
batch_size,
num_channels_latents,
num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
)
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
@@ -225,7 +225,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
@@ -634,7 +637,10 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
return latents.to(device=device, dtype=dtype), latent_image_ids
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
@@ -222,11 +222,13 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=latent_channels,
vae_latent_channels=self.latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
@@ -653,7 +655,10 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
@@ -710,7 +715,9 @@ class FluxInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FluxIPAdapterM
else:
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
masked_image_latents = (
masked_image_latents - self.vae.config.shift_factor
) * self.vae.config.scaling_factor
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
@@ -54,6 +54,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import HunyuanVideoImageToVideoPipeline, HunyuanVideoTransformer3DModel
>>> from diffusers.utils import load_image, export_to_video
>>> # Available checkpoints: hunyuanvideo-community/HunyuanVideo-I2V, hunyuanvideo-community/HunyuanVideo-I2V-33ch
>>> model_id = "hunyuanvideo-community/HunyuanVideo-I2V"
>>> transformer = HunyuanVideoTransformer3DModel.from_pretrained(
... model_id, subfolder="transformer", torch_dtype=torch.bfloat16
@@ -69,7 +70,12 @@ EXAMPLE_DOC_STRING = """
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/guitar-man.png"
... )
>>> output = pipe(image=image, prompt=prompt).frames[0]
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V
>>> output = pipe(image=image, prompt=prompt, guidance_scale=6.0).frames[0]
>>> # If using hunyuanvideo-community/HunyuanVideo-I2V-33ch
>>> output = pipe(image=image, prompt=prompt, guidance_scale=1.0, true_cfg_scale=1.0).frames[0]
>>> export_to_video(output, "output.mp4", fps=15)
```
"""
@@ -399,7 +405,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 256,
):
image_embed_interleave: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds(
image,
@@ -409,6 +416,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device=device,
dtype=dtype,
max_sequence_length=max_sequence_length,
image_embed_interleave=image_embed_interleave,
)
if pooled_prompt_embeds is None:
@@ -433,6 +441,8 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
prompt_template=None,
true_cfg_scale=1.0,
guidance_scale=1.0,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -471,6 +481,13 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}"
)
if true_cfg_scale > 1.0 and guidance_scale > 1.0:
logger.warning(
"Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both "
"classifier-free guidance and embedded-guidance to be applied. This is not recommended "
"as it may lead to higher memory usage, slower inference and potentially worse results."
)
def prepare_latents(
self,
image: torch.Tensor,
@@ -483,6 +500,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
image_condition_type: str = "latent_concat",
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -497,10 +515,11 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
image = image.unsqueeze(2) # [B, C, 1, H, W]
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax")
for i in range(batch_size)
]
else:
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image]
image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor
image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1)
@@ -513,6 +532,9 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
t = torch.tensor([0.999]).to(device=device)
latents = latents * t + image_latents * (1 - t)
if image_condition_type == "token_replace":
image_latents = image_latents[:, :, :1]
return latents, image_latents
def enable_vae_slicing(self):
@@ -598,6 +620,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE,
max_sequence_length: int = 256,
image_embed_interleave: Optional[int] = None,
):
r"""
The call function to the pipeline for generation.
@@ -704,12 +727,22 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_embeds,
callback_on_step_end_tensor_inputs,
prompt_template,
true_cfg_scale,
guidance_scale,
)
image_condition_type = self.transformer.config.image_condition_type
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
image_embed_interleave = (
image_embed_interleave
if image_embed_interleave is not None
else (
2 if image_condition_type == "latent_concat" else 4 if image_condition_type == "token_replace" else 1
)
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
@@ -729,7 +762,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
# 3. Prepare latent variables
vae_dtype = self.vae.dtype
image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype)
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
if image_condition_type == "latent_concat":
num_channels_latents = (self.transformer.config.in_channels - 1) // 2
elif image_condition_type == "token_replace":
num_channels_latents = self.transformer.config.in_channels
latents, image_latents = self.prepare_latents(
image_tensor,
batch_size * num_videos_per_prompt,
@@ -741,10 +779,12 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
device,
generator,
latents,
image_condition_type,
)
image_latents[:, :, 1:] = 0
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
mask[:, :, 1:] = 0
if image_condition_type == "latent_concat":
image_latents[:, :, 1:] = 0
mask = image_latents.new_ones(image_latents.shape[0], 1, *image_latents.shape[2:])
mask[:, :, 1:] = 0
# 4. Encode input prompt
transformer_dtype = self.transformer.dtype
@@ -759,6 +799,7 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
prompt_attention_mask=prompt_attention_mask,
device=device,
max_sequence_length=max_sequence_length,
image_embed_interleave=image_embed_interleave,
)
prompt_embeds = prompt_embeds.to(transformer_dtype)
prompt_attention_mask = prompt_attention_mask.to(transformer_dtype)
@@ -782,10 +823,17 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
# 6. Prepare guidance condition
guidance = None
if self.transformer.config.guidance_embeds:
guidance = (
torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0
)
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
@@ -796,16 +844,21 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
continue
self._current_timestep = t
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
if image_condition_type == "latent_concat":
latent_model_input = torch.cat([latents, image_latents, mask], dim=1).to(transformer_dtype)
elif image_condition_type == "token_replace":
latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
@@ -817,13 +870,20 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if image_condition_type == "latent_concat":
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
elif image_condition_type == "token_replace":
latents = latents = self.scheduler.step(
noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
)[0]
latents = torch.cat([image_latents, latents], dim=2)
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -844,12 +904,16 @@ class HunyuanVideoImageToVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoader
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
latents = latents.to(self.vae.dtype) / self.vae_scaling_factor
video = self.vae.decode(latents, return_dict=False)[0]
video = video[:, :, 4:, :, :]
if image_condition_type == "latent_concat":
video = video[:, :, 4:, :, :]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents[:, :, 1:, :, :]
if image_condition_type == "latent_concat":
video = latents[:, :, 1:, :, :]
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
@@ -19,7 +19,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
@@ -121,7 +121,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin):
r"""
Pipeline for text-to-image generation using Kolors.
@@ -129,8 +129,8 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
+6 -29
View File
@@ -104,13 +104,6 @@ class RMSNorm(torch.nn.Module):
return (self.weight * hidden_states).to(input_dtype)
def _config_to_kwargs(args):
common_kwargs = {
"dtype": args.torch_dtype,
}
return common_kwargs
class CoreAttention(torch.nn.Module):
def __init__(self, config: ChatGLMConfig, layer_number):
super(CoreAttention, self).__init__()
@@ -314,7 +307,6 @@ class SelfAttention(torch.nn.Module):
self.qkv_hidden_size,
bias=config.add_bias_linear or config.add_qkv_bias,
device=device,
**_config_to_kwargs(config),
)
self.core_attention = CoreAttention(config, self.layer_number)
@@ -325,7 +317,6 @@ class SelfAttention(torch.nn.Module):
config.hidden_size,
bias=config.add_bias_linear,
device=device,
**_config_to_kwargs(config),
)
def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
@@ -449,7 +440,6 @@ class MLP(torch.nn.Module):
config.ffn_hidden_size * 2,
bias=self.add_bias,
device=device,
**_config_to_kwargs(config),
)
def swiglu(x):
@@ -459,9 +449,7 @@ class MLP(torch.nn.Module):
self.activation_func = swiglu
# Project back to h.
self.dense_4h_to_h = nn.Linear(
config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device, **_config_to_kwargs(config)
)
self.dense_4h_to_h = nn.Linear(config.ffn_hidden_size, config.hidden_size, bias=self.add_bias, device=device)
def forward(self, hidden_states):
# [s, b, 4hp]
@@ -488,18 +476,14 @@ class GLMBlock(torch.nn.Module):
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = LayerNormFunc(
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
)
self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
# Self attention.
self.self_attention = SelfAttention(config, layer_number, device=device)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
self.post_attention_layernorm = LayerNormFunc(
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
)
self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
# MLP
self.mlp = MLP(config, device=device)
@@ -569,9 +553,7 @@ class GLMTransformer(torch.nn.Module):
if self.post_layer_norm:
LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = LayerNormFunc(
config.hidden_size, eps=config.layernorm_epsilon, device=device, dtype=config.torch_dtype
)
self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device)
self.gradient_checkpointing = False
@@ -679,9 +661,7 @@ class Embedding(torch.nn.Module):
self.hidden_size = config.hidden_size
# Word embeddings (parallel).
self.word_embeddings = nn.Embedding(
config.padded_vocab_size, self.hidden_size, dtype=config.torch_dtype, device=device
)
self.word_embeddings = nn.Embedding(config.padded_vocab_size, self.hidden_size, device=device)
self.fp32_residual_connection = config.fp32_residual_connection
def forward(self, input_ids):
@@ -784,16 +764,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
)
self.rotary_pos_emb = RotaryEmbedding(
rotary_dim // 2, original_impl=config.original_rope, device=device, dtype=config.torch_dtype
)
self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device)
self.encoder = init_method(GLMTransformer, config, **init_kwargs)
self.output_layer = init_method(
nn.Linear,
config.hidden_size,
config.padded_vocab_size,
bias=False,
dtype=config.torch_dtype,
**init_kwargs,
)
self.pre_seq_len = config.pre_seq_len
@@ -817,7 +817,7 @@ class LattePipeline(DiffusionPipeline):
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=current_timestep,
enable_temporal_attentions=enable_temporal_attentions,
+2
View File
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -34,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
else:
+8 -2
View File
@@ -489,6 +489,10 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def attention_kwargs(self):
return self._attention_kwargs
@@ -622,6 +626,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -694,9 +699,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
@@ -707,6 +711,8 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
File diff suppressed because it is too large Load Diff
@@ -487,19 +487,21 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
) -> torch.Tensor:
height = height // self.vae_spatial_compression_ratio
width = width // self.vae_spatial_compression_ratio
num_frames = (
(num_frames - 1) // self.vae_temporal_compression_ratio + 1 if latents is None else latents.size(2)
)
num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
shape = (batch_size, num_channels_latents, num_frames, height, width)
mask_shape = (batch_size, 1, num_frames, height, width)
if latents is not None:
conditioning_mask = latents.new_zeros(shape)
conditioning_mask = latents.new_zeros(mask_shape)
conditioning_mask[:, :, 0] = 1.0
conditioning_mask = self._pack_latents(
conditioning_mask, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
).squeeze(-1)
if latents.ndim != 3 or latents.shape[:2] != conditioning_mask.shape:
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is {conditioning_mask.shape + (num_channels_latents,)}."
)
return latents.to(device=device, dtype=dtype), conditioning_mask
if isinstance(generator, list):
@@ -548,6 +550,10 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def attention_kwargs(self):
return self._attention_kwargs
@@ -684,6 +690,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
@@ -764,9 +771,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
latent_frame_rate = frame_rate / self.vae_temporal_compression_ratio
rope_interpolation_scale = (
1 / latent_frame_rate,
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
@@ -777,6 +783,8 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
@@ -592,6 +592,11 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
loaded_sub_model = passed_class_obj[name]
else:
sub_model_dtype = (
torch_dtype.get(name, torch_dtype.get("default", torch.float32))
if isinstance(torch_dtype, dict)
else torch_dtype
)
loaded_sub_model = _load_empty_model(
library_name=library_name,
class_name=class_name,
@@ -600,7 +605,7 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
is_pipeline_module=is_pipeline_module,
pipeline_class=pipeline_class,
name=name,
torch_dtype=torch_dtype,
torch_dtype=sub_model_dtype,
cached_folder=kwargs.get("cached_folder", None),
force_download=kwargs.get("force_download", None),
proxies=kwargs.get("proxies", None),
@@ -616,7 +621,12 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
# Obtain a sorted dictionary for mapping the model-level components
# to their sizes.
module_sizes = {
module_name: compute_module_sizes(module, dtype=torch_dtype)[""]
module_name: compute_module_sizes(
module,
dtype=torch_dtype.get(module_name, torch_dtype.get("default", torch.float32))
if isinstance(torch_dtype, dict)
else torch_dtype,
)[""]
for module_name, module in init_empty_modules.items()
if isinstance(module, torch.nn.Module)
}
+18 -10
View File
@@ -427,7 +427,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
)
if device_type == "cuda":
if device_type in ["cuda", "xpu"]:
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
@@ -440,7 +440,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# Display a warning in this case (the operation succeeds but the benefits are lost)
pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
if pipeline_is_offloaded and device_type == "cuda":
if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
logger.warning(
f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
)
@@ -552,9 +552,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
saved using
[`~DiffusionPipeline.save_pretrained`].
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file
torch_dtype (`str` or `torch.dtype`, *optional*):
torch_dtype (`str` or `torch.dtype` or `dict[str, Union[str, torch.dtype]]`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
dtype is automatically derived from the model's weights.
dtype is automatically derived from the model's weights. To load submodels with different dtype pass a
`dict` (for example `{'transformer': torch.bfloat16, 'vae': torch.float16}`). Set the default dtype for
unspecified components with `default` (for example `{'transformer': torch.bfloat16, 'default':
torch.float16}`). If a component is not specified and no default is set, `torch.float32` is used.
custom_pipeline (`str`, *optional*):
<Tip warning={true}>
@@ -686,7 +689,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
from_flax = kwargs.pop("from_flax", False)
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
torch_dtype = kwargs.pop("torch_dtype", None)
custom_pipeline = kwargs.pop("custom_pipeline", None)
custom_revision = kwargs.pop("custom_revision", None)
provider = kwargs.pop("provider", None)
@@ -703,7 +706,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
if not isinstance(torch_dtype, torch.dtype):
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
@@ -950,6 +953,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
loaded_sub_model = passed_class_obj[name]
else:
# load sub model
sub_model_dtype = (
torch_dtype.get(name, torch_dtype.get("default", torch.float32))
if isinstance(torch_dtype, dict)
else torch_dtype
)
loaded_sub_model = load_sub_model(
library_name=library_name,
class_name=class_name,
@@ -957,7 +965,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
pipelines=pipelines,
is_pipeline_module=is_pipeline_module,
pipeline_class=pipeline_class,
torch_dtype=torch_dtype,
torch_dtype=sub_model_dtype,
provider=provider,
sess_options=sess_options,
device_map=current_device_map,
@@ -998,7 +1006,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - set(optional_kwargs)
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
@@ -1456,8 +1464,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if load_components_from_hub and not trust_remote_code:
raise ValueError(
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k,v in custom_components.items()])} which must be executed to correctly "
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k,v in custom_components.items()])}.\n"
f"The repository for {pretrained_model_name} contains custom code in {'.py, '.join([os.path.join(k, v) for k, v in custom_components.items()])} which must be executed to correctly "
f"load the model. You can inspect the repository content at {', '.join([f'https://hf.co/{pretrained_model_name}/{k}/{v}.py' for k, v in custom_components.items()])}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
@@ -941,8 +941,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
# compute previous image: x_t -> x_t-1
if num_inference_steps == 1:
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1]
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
@@ -868,7 +868,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
xm.mark_step()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
else:
+2
View File
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_sana"] = ["SanaPipeline"]
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_sana import SanaPipeline
from .pipeline_sana_sprint import SanaSprintPipeline
else:
import sys
+82 -51
View File
@@ -248,6 +248,64 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
"""
self.vae.disable_tiling()
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
device: torch.device,
dtype: torch.dtype,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
prompt = [prompt] if isinstance(prompt, str) else prompt
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_sequence_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -296,6 +354,13 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if device is None:
device = self._execution_device
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
@@ -320,43 +385,18 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
select_index = [0] + list(range(-max_length + 1, 0))
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0][:, select_index]
prompt_embeds = prompt_embeds[:, select_index]
prompt_attention_mask = prompt_attention_mask[:, select_index]
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -366,25 +406,15 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=negative_prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=False,
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
@@ -908,6 +938,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
timestep = timestep * self.transformer.config.timestep_scale
# predict noise model_output
noise_pred = self.transformer(
@@ -0,0 +1,889 @@
# Copyright 2024 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PixArtImageProcessor
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, SanaTransformer2DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
USE_PEFT_BACKEND,
is_bs4_available,
is_ftfy_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..pixart_alpha.pipeline_pixart_alpha import ASPECT_RATIO_1024_BIN
from .pipeline_output import SanaPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
from bs4 import BeautifulSoup
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import SanaSprintPipeline
>>> pipe = SanaSprintPipeline.from_pretrained(
... "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers", torch_dtype=torch.bfloat16
... )
>>> pipe.to("cuda")
>>> image = pipe(prompt="a tiny astronaut hatching from an egg on the moon")[0]
>>> image[0].save("output.png")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
r"""
Pipeline for text-to-image generation using [SANA-Sprint](https://huggingface.co/papers/2503.09641).
"""
# fmt: off
bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
# fmt: on
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC,
transformer: SanaTransformer2DModel,
scheduler: DPMSolverMultistepScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
if hasattr(self, "vae") and self.vae is not None
else 32
)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
device: torch.device,
dtype: torch.dtype,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
prompt = [prompt] if isinstance(prompt, str) else prompt
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_sequence_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
if device is None:
device = self._execution_device
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
# See Section 3.1. of the paper.
max_length = max_sequence_length
select_index = [0] + list(range(-max_length + 1, 0))
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
)
prompt_embeds = prompt_embeds[:, select_index]
prompt_attention_mask = prompt_attention_mask[:, select_index]
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
if self.text_encoder is not None:
if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
num_inference_steps,
timesteps,
max_timesteps,
intermediate_timesteps,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
if timesteps is not None and max_timesteps is not None:
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
if timesteps is None and max_timesteps is None:
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
if intermediate_timesteps is not None and num_inference_steps != 2:
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device=device, dtype=dtype)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_inference_steps: int = 2,
timesteps: List[int] = None,
max_timesteps: float = 1.57080,
intermediate_timesteps: float = 1.3,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: int = 1024,
width: int = 1024,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
clean_caption: bool = False,
use_resolution_binning: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 300,
complex_human_instruction: List[str] = [
"Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
"- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
"Here are examples of how to transform or refine prompts:",
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
"Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
"User Prompt: ",
],
) -> Union[SanaPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
num_inference_steps (`int`, *optional*, defaults to 20):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
max_timesteps (`float`, *optional*, defaults to 1.57080):
The maximum timestep value used in the SCM scheduler.
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
attention_kwargs:
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to `300`):
Maximum sequence length to use with the `prompt`.
complex_human_instruction (`List[str]`, *optional*):
Instructions for complex human attention:
https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
Examples:
Returns:
[`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
if use_resolution_binning:
if self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
max_timesteps=max_timesteps,
intermediate_timesteps=intermediate_timesteps,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
) = self.encode_prompt(
prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
lora_scale=lora_scale,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas=None,
max_timesteps=max_timesteps,
intermediate_timesteps=intermediate_timesteps,
)
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(0)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
torch.float32,
device,
generator,
latents,
)
latents = latents * self.scheduler.config.sigma_data
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
guidance = guidance * self.transformer.config.guidance_embeds_scale
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
timesteps = timesteps[:-1]
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
latents_model_input = latents / self.scheduler.config.sigma_data
scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
scm_timestep_expanded = scm_timestep.view(-1, 1, 1, 1)
latent_model_input = latents_model_input * torch.sqrt(
scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2
)
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
guidance=guidance,
timestep=scm_timestep,
return_dict=False,
attention_kwargs=self.attention_kwargs,
)[0]
noise_pred = (
(1 - 2 * scm_timestep_expanded) * latent_model_input
+ (1 - 2 * scm_timestep_expanded + 2 * scm_timestep_expanded**2) * noise_pred
) / torch.sqrt(scm_timestep_expanded**2 + (1 - scm_timestep_expanded) ** 2)
noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
# compute previous image: x_t -> x_t-1
latents, denoised = self.scheduler.step(
noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
latents = denoised / self.scheduler.config.sigma_data
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return SanaPipelineOutput(images=image)
@@ -52,6 +52,7 @@ from ...schedulers import (
UnCLIPScheduler,
)
from ...utils import is_accelerate_available, logging
from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
@@ -1324,7 +1325,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
+2 -1
View File
@@ -24,7 +24,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_wan"] = ["WanPipeline"]
_import_structure["pipeline_wan_i2v"] = ["WanImageToVideoPipeline"]
_import_structure["pipeline_wan_video2video"] = ["WanVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -35,6 +35,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_wan import WanPipeline
from .pipeline_wan_i2v import WanImageToVideoPipeline
from .pipeline_wan_video2video import WanVideoToVideoPipeline
else:
import sys
@@ -458,6 +458,13 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
callback_on_step_end_tensor_inputs,
)
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
+42 -24
View File
@@ -108,31 +108,16 @@ def prompt_clean(text):
return text
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor,
latents_mean: torch.Tensor,
latents_std: torch.Tensor,
generator: Optional[torch.Generator] = None,
sample_mode: str = "sample",
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
encoder_output.latent_dist.logvar = torch.clamp(
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
)
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
encoder_output.latent_dist.logvar = torch.clamp(
(encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
)
encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return (encoder_output.latents - latents_mean) * latents_std
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
@@ -235,8 +220,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
return prompt_embeds
def encode_image(self, image: PipelineImageInput):
image = self.image_processor(images=image, return_tensors="pt").to(self.device)
def encode_image(
self,
image: PipelineImageInput,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
image = self.image_processor(images=image, return_tensors="pt").to(device)
image_embeds = self.image_encoder(**image, output_hidden_states=True)
return image_embeds.hidden_states[-2]
@@ -331,9 +321,19 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
width,
prompt_embeds=None,
negative_prompt_embeds=None,
image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
if image is not None and image_embeds is not None:
raise ValueError(
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
" only forward one of the two."
)
if image is None and image_embeds is None:
raise ValueError(
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
)
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
raise ValueError("`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is" f" {type(image)}")
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -412,13 +412,15 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if isinstance(generator, list):
latent_condition = [
retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
]
latent_condition = torch.cat(latent_condition)
else:
latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latent_condition = (latent_condition - latents_mean) * latents_std
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
@@ -471,6 +473,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
image_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -520,6 +523,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `negative_prompt` input argument.
image_embeds (`torch.Tensor`, *optional*):
Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
image embeddings are generated from the `image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -564,9 +573,17 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
width,
prompt_embeds,
negative_prompt_embeds,
image_embeds,
callback_on_step_end_tensor_inputs,
)
if num_frames % self.vae_scale_factor_temporal != 1:
logger.warning(
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
)
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
@@ -600,7 +617,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
image_embeds = self.encode_image(image)
if image_embeds is None:
image_embeds = self.encode_image(image, device)
image_embeds = image_embeds.repeat(batch_size, 1, 1)
image_embeds = image_embeds.to(transformer_dtype)

Some files were not shown because too many files have changed in this diff Show More