Compare commits

..

34 Commits

Author SHA1 Message Date
Sayak Paul e72648a311 Merge branch 'main' into gpu-pr-test 2025-07-09 09:04:31 +05:30
Aryan 0454fbb30b First Block Cache (#11180)
* update

* modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code)

* remove debug logs

* update

* cache context for different batches of data

* fix hs residual bug for single return outputs; support ltx

* fix controlnet flux

* support flux, ltx i2v, ltx condition

* update

* update

* Update docs/source/en/api/cache.md

* Update src/diffusers/hooks/hooks.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* address review comments pt. 1

* address review comments pt. 2

* cache context refacotr; address review pt. 3

* address review comments

* metadata registration with decorators instead of centralized

* support cogvideox

* support mochi

* fix

* remove unused function

* remove central registry based on review

* update

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-07-09 03:27:15 +05:30
DN6 3e3c0fcc1c update 2025-07-08 22:38:50 +05:30
Dhruv Nair cbc8ced20f [CI] Fix big GPU test marker (#11786)
* update

* update
2025-07-08 22:09:09 +05:30
Sayak Paul 01240fecb0 [training ] add Kontext i2i training (#11858)
* feat: enable i2i fine-tuning in Kontext script.

* readme

* more checks.

* Apply suggestions from code review

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>

* fixes

* fix

* add proj_mlp to the mix

* Update README_flux.md

add note on installing from commit `05e7a854d0a5661f5b433f6dd5954c224b104f0b`

* fix

* fix

---------

Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2025-07-08 21:04:16 +05:30
Steven Liu ce338d4e4a [docs] LoRA metadata (#11848)
* draft

* hub image

* update

* fix
2025-07-08 08:29:38 -07:00
Sayak Paul bc55b631fd [tests] remove tests for deprecated pipelines. (#11879)
* remove tests for deprecated pipelines.

* remove folders

* test_pipelines_common
2025-07-08 07:13:16 +05:30
Sayak Paul 15d50f16f2 [docs] fix references in flux pipelines. (#11857)
* fix references in flux.

* Update src/diffusers/pipelines/flux/pipeline_flux_kontext.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-07-07 22:20:34 +05:30
Sayak Paul 2c30287958 [chore] deprecate blip controlnet pipeline. (#11877)
* deprecate blip controlnet pipeline.

* last_supported_version
2025-07-07 13:25:40 +05:30
Aryan 425a715e35 Fix Wan AccVideo/CausVid fuse_lora (#11856)
* fix

* actually, better fix

* empty commit; trigger tests again

* mark wanvace test as flaky
2025-07-04 21:10:35 +05:30
Benjamin Bossan 2527917528 FIX set_lora_device when target layers differ (#11844)
* FIX set_lora_device when target layers differ

Resolves #11833

Fixes a bug that occurs after calling set_lora_device when multiple LoRA
adapters are loaded that target different layers.

Note: Technically, the accompanying test does not require a GPU because
the bug is triggered even if the parameters are already on the
corresponding device, i.e. loading on CPU and then changing the device
to CPU is sufficient to cause the bug. However, this may be optimized
away in the future, so I decided to test with GPU.

* Update docstring to warn about device mismatch

* Extend docstring with an example

* Fix docstring

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-07-04 19:26:17 +05:30
Sayak Paul e6639fef70 [benchmarks] overhaul benchmarks (#11565)
* start overhauling the benchmarking suite.

* fixes

* fixes

* checking.

* checking

* fixes.

* error handling and logging.

* add flops and params.

* add more models.

* utility to fire execution of all benchmarking scripts.

* utility to push to the hub.

* push utility improvement

* seems to be working.

* okay

* add torchprofile dep.

* remove total gpu memory

* fixes

* fix

* need a big gpu

* better

* what's happening.

* okay

* separate requirements and make it nightly.

* add db population script.

* update secret name

* update secret.

* population db update

* disable db population for now.

* change to every monday

* Update .github/workflows/benchmark.yml

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* quality improvements.

* reparate hub upload step.

* repository

* remove csv

* check

* update

* update

* threading.

* update

* update

* updaye

* update

* update

* update

* remove peft dep

* upgrade runner.

* fix

* fixes

* fix merging csvs.

* push dataset to the Space repo for analysis.

* warm up.

* add a readme

* Apply suggestions from code review

Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>

* address feedback

* Apply suggestions from code review

* disable db workflow.

* update to bi weekly.

* enable population

* enable

* updaye

* update

* metadata

* fix

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Luc Georges <McPatate@users.noreply.github.com>
2025-07-04 11:04:17 +05:30
Aryan 8c938fb410 [docs] Add a note of _keep_in_fp32_modules (#11851)
* update

* Update docs/source/en/using-diffusers/schedulers.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update schedulers.md

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-07-02 15:51:57 -07:00
Linoy Tsaban f864a9a352 [Flux Kontext] Support Fal Kontext LoRA (#11823)
* initial commit

* initial commit

* initial commit

* fix import

* fix prefix

* remove print

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-02 16:57:08 +03:00
Vương Đình Minh d6fa3298fa update: FluxKontextInpaintPipeline support (#11820)
* update: FluxKontextInpaintPipeline support

* fix: Refactor code, remove mask_image_latents and ruff check

* feat: Add test case and fix with pytest

* Apply style fixes

* copies

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-01 23:34:27 -10:00
Sayak Paul 6f1d6694df [lora] tests for exclude_modules with Wan VACE (#11843)
* wan vace.

* update

* update

* import problem
2025-07-02 14:23:26 +05:30
Ju Hoon Park 0e95aa853e [From Single File] support from_single_file method for WanVACE3DTransformer (#11807)
* add `WandVACETransformer3DModel` in`SINGLE_FILE_LOADABLE_CLASSES`

* add rename keys for `VACE`

add rename keys for `VACE`

* fix typo

Sincere thanks to @nitinmukesh 🙇‍♂️

* support for `1.3B VACE` model

Sincere thanks to @nitinmukesh again🙇‍♂️

* update

* update

* Apply style fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-02 05:55:36 +02:00
Luo Yihang 5ef74fd5f6 fix norm not training in train_control_lora_flux.py (#11832) 2025-07-01 17:37:54 -10:00
Steven Liu 64a9210315 [docs] Deprecated pipelines (#11838)
add warning

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-07-01 14:02:54 -10:00
Steven Liu d31b8cea3e [docs] Batch generation (#11841)
* draft

* fix

* fix

* feedback

* feedback
2025-07-01 17:00:20 -07:00
Mikko Tukiainen 62e847db5f Use real-valued instead of complex tensors in Wan2.1 RoPE (#11649)
* use real instead of complex tensors in Wan2.1 RoPE

* remove the redundant type conversion

* unpack rotary_emb

* register rotary embedding frequencies as non-persistent buffers

* Apply style fixes

---------

Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-07-01 13:57:19 -10:00
Sayak Paul 470458623e [docs] fix single_file example. (#11847)
fix single_file example.
2025-07-01 21:23:27 +05:30
Aryan a79c3af6bb [single file] Cosmos (#11801)
* update

* update

* update docs
2025-07-01 18:02:58 +05:30
Aryan 3f3f0c16a6 [tests] Fix failing float16 cuda tests (#11835)
* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-07-01 11:13:58 +05:30
jiqing-feng f3e1310469 reset deterministic in tearDownClass (#11785)
* reset deterministic in tearDownClass

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix deterministic setting

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-07-01 10:06:54 +05:30
Sayak Paul 87f83d3dd9 [tests] add test for hotswapping + compilation on resolution changes (#11825)
* add resolution changes tests to hotswapping test suite.

* fixes

* docs

* explain duck shapes

* fix
2025-07-01 09:40:34 +05:30
Aryan f064b3bf73 Remove print statement in SCM Scheduler (#11836)
remove print
2025-06-30 09:07:34 -10:00
Benjamin Bossan 3b079ec3fa ENH: Improve speed of function expanding LoRA scales (#11834)
* ENH Improve speed of expanding LoRA scales

Resolves #11816

The following call proved to be a bottleneck when setting a lot of LoRA
adapters in diffusers:

https://github.com/huggingface/diffusers/blob/cdaf84a708eadf17d731657f4be3fa39d09a12c0/src/diffusers/loaders/peft.py#L482

This is because we would repeatedly call unet.state_dict(), even though
in the standard case, it is not necessary:

https://github.com/huggingface/diffusers/blob/cdaf84a708eadf17d731657f4be3fa39d09a12c0/src/diffusers/loaders/unet_loader_utils.py#L55

This PR fixes this by deferring this call, so that it is only run when
it's necessary, not earlier.

* Small fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-06-30 20:25:56 +05:30
Sayak Paul bc34fa8386 [lora]feat: use exclude modules to loraconfig. (#11806)
* feat: use exclude modules to loraconfig.

* version-guard.

* tests and version guard.

* remove print.

* describe the test

* more detailed warning message + shift to debug

* update

* update

* update

* remove test
2025-06-30 20:08:53 +05:30
Sayak Paul 05e7a854d0 [lora] fix: lora unloading behvaiour (#11822)
* fix: lora unloading behvaiour

* fix

* update
2025-06-28 12:00:42 +05:30
Aryan 76ec3d1fee Support dynamically loading/unloading loras with group offloading (#11804)
* update

* add test

* address review comments

* update

* fixes

* change decorator order to fix tests

* try fix

* fight tests
2025-06-27 23:20:53 +05:30
Aryan cdaf84a708 TorchAO compile + offloading tests (#11697)
* update

* update

* update

* update

* update

* user property instead
2025-06-27 18:31:57 +05:30
Sayak Paul e8e44a510c [CI] disable onnx, mps, flax from the CI (#11803)
* disable onnx, mps, flax

* remove
2025-06-27 16:33:43 +05:30
Sayak Paul 21543de571 remove syncs before denoising in Kontext (#11818) 2025-06-27 15:57:55 +05:30
192 changed files with 5636 additions and 12352 deletions
+31 -10
@@ -11,17 +11,18 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
BASE_PATH: benchmark_outputs
jobs:
torch_pipelines_cuda_benchmark_tests:
torch_models_cuda_benchmark_tests:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
name: Torch Core Pipelines CUDA Benchmarking Tests
name: Torch Core Models CUDA Benchmarking Tests
strategy:
fail-fast: false
max-parallel: 1
runs-on:
group: aws-g6-4xlarge-plus
group: aws-g6e-4xlarge
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host --gpus 0
@@ -35,27 +36,47 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
apt update
apt install -y libpq-dev postgresql-client
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install pandas peft
python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
python -m uv pip install -r benchmarks/requirements.txt
- name: Environment
run: |
python utils/print_env.py
- name: Diffusers Benchmarking
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
BASE_PATH: benchmark_outputs
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
cd benchmarks && python run_all.py
- name: Push results to the Hub
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
run: |
cd benchmarks && python push_results.py
mkdir $BASE_PATH && cp *.csv $BASE_PATH
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: benchmark_test_reports
path: benchmarks/benchmark_outputs
path: benchmarks/${{ env.BASE_PATH }}
# TODO: enable this once the connection problem has been resolved.
- name: Update benchmarking results to DB
env:
PGDATABASE: metrics
PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
PGUSER: transformers_benchmarks
PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
run: |
git config --global --add safe.directory /__w/diffusers/diffusers
commit_id=$GITHUB_SHA
commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
- name: Report success status
if: ${{ success() }}
@@ -75,10 +75,6 @@ jobs:
- diffusers-pytorch-cuda
- diffusers-pytorch-xformers-cuda
- diffusers-pytorch-minimum-cuda
- diffusers-flax-cpu
- diffusers-flax-tpu
- diffusers-onnxruntime-cpu
- diffusers-onnxruntime-cuda
- diffusers-doc-builder
steps:
+3 -103
@@ -248,7 +248,7 @@ jobs:
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-m "big_gpu_with_torch_cuda" \
-m "big_accelerator" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
tests/
@@ -321,55 +321,6 @@ jobs:
name: torch_minimum_version_cuda_test_reports
path: reports
run_nightly_onnx_tests:
name: Nightly ONNXRuntime CUDA tests on Ubuntu
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-onnxruntime-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
- name: Run Nightly ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_onnx_cuda \
--report-log=tests_onnx_cuda.log \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_onnx_cuda_stats.txt
cat reports/tests_onnx_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: tests_onnx_cuda_reports
path: reports
run_nightly_quantization_tests:
name: Torch quantization nightly tests
strategy:
@@ -485,57 +436,6 @@ jobs:
name: torch_cuda_pipeline_level_quant_reports
path: reports
run_flax_tpu_tests:
name: Nightly Flax TPU Tests
runs-on:
group: gcp-ct5lp-hightpu-8t
if: github.event_name == 'schedule'
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
- name: Run nightly Flax TPU tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_flax_tpu \
--report-log=tests_flax_tpu.log \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_flax_tpu_stats.txt
cat reports/tests_flax_tpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: flax_tpu_test_reports
path: reports
generate_consolidated_report:
name: Generate Consolidated Test Report
needs: [
@@ -545,9 +445,9 @@ jobs:
run_big_gpu_torch_tests,
run_nightly_quantization_tests,
run_nightly_pipeline_level_quantization_tests,
run_nightly_onnx_tests,
# run_nightly_onnx_tests,
torch_minimum_version_cuda_tests,
run_flax_tpu_tests
# run_flax_tpu_tests
]
if: always()
runs-on:
-14
@@ -87,11 +87,6 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_models_schedulers
- name: Fast Flax CPU tests
framework: flax
runner: aws-general-8-plus
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: PyTorch Example CPU tests
framework: pytorch_examples
runner: aws-general-8-plus
@@ -147,15 +142,6 @@ jobs:
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
- name: Run fast Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
+1 -1
@@ -188,7 +188,7 @@ jobs:
shell: bash
strategy:
fail-fast: false
max-parallel: 2
max-parallel: 4
matrix:
module: [models, schedulers, lora, others]
steps:
-96
@@ -159,102 +159,6 @@ jobs:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
flax_tpu_tests:
name: Flax TPU Tests
runs-on:
group: gcp-ct5lp-hightpu-8t
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run Flax TPU tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_flax_tpu \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_flax_tpu_stats.txt
cat reports/tests_flax_tpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: flax_tpu_test_reports
path: reports
onnx_cuda_tests:
name: ONNX CUDA Tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-onnxruntime-cuda
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_onnx_cuda \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_onnx_cuda_stats.txt
cat reports/tests_onnx_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: onnx_cuda_test_reports
path: reports
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
-28
@@ -33,16 +33,6 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu
- name: Fast Flax CPU tests on Ubuntu
framework: flax
runner: aws-general-8-plus
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
framework: onnxruntime
runner: aws-general-8-plus
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: aws-general-8-plus
@@ -87,24 +77,6 @@ jobs:
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run fast Flax TPU tests
if: ${{ matrix.config.framework == 'flax' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Flax" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run fast ONNXRuntime CPU tests
if: ${{ matrix.config.framework == 'onnxruntime' }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
+1 -6
@@ -1,12 +1,7 @@
name: Fast mps tests on main
on:
push:
branches:
- main
paths:
- "src/diffusers/**.py"
- "tests/**.py"
workflow_dispatch:
env:
DIFFUSERS_IS_CI: yes
-95
@@ -213,101 +213,6 @@ jobs:
with:
name: torch_minimum_version_cuda_test_reports
path: reports
flax_tpu_tests:
name: Flax TPU Tests
runs-on: docker-tpu
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run slow Flax TPU tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_flax_tpu \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_flax_tpu_stats.txt
cat reports/tests_flax_tpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: flax_tpu_test_reports
path: reports
onnx_cuda_tests:
name: ONNX CUDA Tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-onnxruntime-cuda
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run slow ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_onnx_cuda \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_onnx_cuda_stats.txt
cat reports/tests_onnx_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: onnx_cuda_test_reports
path: reports
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
+69
@@ -0,0 +1,69 @@
# Diffusers Benchmarks
Welcome to Diffusers Benchmarks. These benchmarks are used to obtain latency and memory information for the most popular models across different scenarios such as:
* Base case i.e., when using `torch.bfloat16` and `torch.nn.functional.scaled_dot_product_attention`.
* Base + `torch.compile()`
* NF4 quantization
* Layerwise upcasting
Instead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is tested with the real checkpoints (such as `"black-forest-labs/FLUX.1-dev"`).
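
As a rough illustration of what measuring the latency of a single forward pass involves, here is a minimal sketch using only the standard library. It is not the suite's actual implementation: `time_callable` and `fake_forward` are hypothetical names, and `fake_forward` merely stands in for a model call.

```python
import statistics
import time


def time_callable(fn, warmup=2, repeats=5):
    """Return the median wall-clock latency of fn() in seconds."""
    # Warm-up calls are discarded so one-time costs (caches, lazy
    # initialization, compilation) do not distort the measurement.
    for _ in range(warmup):
        fn()

    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def fake_forward():
    # Hypothetical stand-in for a model forward pass.
    return sum(i * i for i in range(10_000))


latency = time_callable(fake_forward)
print(f"median latency: {latency * 1e3:.3f} ms")
```

When timing CUDA work, the device also has to be synchronized before each clock read (for example with `torch.cuda.synchronize()`), because kernel launches are asynchronous.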
The entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, e.g., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run.
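
To get a quick look at a produced CSV without assuming its schema, something like the sketch below could be used. The column names in the sample file (`scenario`, `time_s`, `memory_gb`) are hypothetical placeholders chosen for illustration, not the columns the benchmarks actually write, and `summarize_csv` is an illustrative helper rather than part of the suite.

```python
import csv
import tempfile
from pathlib import Path


def summarize_csv(path):
    """Return (column_names, row_count) for a CSV file with a header row."""
    with open(path, newline="") as f:
        reader = csv.DictReader(f)
        rows = list(reader)
        # fieldnames is None for an empty file, so fall back to an empty list.
        return list(reader.fieldnames or []), len(rows)


# Hypothetical sample data so the sketch runs without a GPU or real results.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "sample.csv"
    sample.write_text("scenario,time_s,memory_gb\nbase,1.0,10.0\ncompiled,0.5,10.5\n")
    columns, n_rows = summarize_csv(sample)

print(columns, n_rows)
```

Pointing `summarize_csv` at a real output file shows which columns a given benchmark run recorded.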
The benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml).
## Running the benchmarks manually
First set up `torch` and install `diffusers` from the root of the repository:
```sh
pip install -e ".[quality,test]"
```
Then make sure the other dependencies are installed:
```sh
cd benchmarks/
pip install -r requirements.txt
```
We need to be authenticated to access some of the checkpoints used during benchmarking:
```sh
huggingface-cli login
```
We use an L40 GPU with 128GB RAM to run the benchmark CI. As such, the benchmarks are configured to run on NVIDIA GPUs. So, make sure you have access to a similar machine (or modify the benchmarking scripts accordingly).
Then you can either launch the entire benchmarking suite by running:
```sh
python run_all.py
```
Or, you can run the individual benchmarks.
## Customizing the benchmarks
We define "scenarios" to cover the most common ways in which these models are used. You can
define a new scenario by modifying an existing benchmark file:
```py
BenchmarkScenario(
    name=f"{CKPT_ID}-bnb-8bit",
    model_cls=FluxTransformer2DModel,
    model_init_kwargs={
        "pretrained_model_name_or_path": CKPT_ID,
        "torch_dtype": torch.bfloat16,
        "subfolder": "transformer",
        "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
    },
    get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
    model_init_fn=model_init_fn,
)
```
You can also configure a new model-level benchmark and add it to the existing suite. To do so, defining a valid benchmarking file such as `benchmarking_flux.py` should be enough.
Happy benchmarking 🧨
-346
@@ -1,346 +0,0 @@
import os
import sys

import torch

from diffusers import (
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
    AutoPipelineForText2Image,
    ControlNetModel,
    LCMScheduler,
    StableDiffusionAdapterPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLControlNetPipeline,
    T2IAdapter,
    WuerstchenCombinedPipeline,
)
from diffusers.utils import load_image

sys.path.append(".")

from utils import (  # noqa: E402
    BASE_PATH,
    PROMPT,
    BenchmarkInfo,
    benchmark_fn,
    bytes_to_giga_bytes,
    flush,
    generate_csv_dict,
    write_to_csv,
)

RESOLUTION_MAPPING = {
    "Lykon/DreamShaper": (512, 512),
    "lllyasviel/sd-controlnet-canny": (512, 512),
    "diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
    "TencentARC/t2iadapter_canny_sd14v1": (512, 512),
    "TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
    "stabilityai/stable-diffusion-2-1": (768, 768),
    "stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
    "stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
    "stabilityai/sdxl-turbo": (512, 512),
}
class BaseBenchmak:
    pipeline_class = None

    def __init__(self, args):
        super().__init__()

    def run_inference(self, args):
        raise NotImplementedError

    def benchmark(self, args):
        raise NotImplementedError

    def get_result_filepath(self, args):
        pipeline_class_name = str(self.pipe.__class__.__name__)
        name = (
            args.ckpt.replace("/", "_")
            + "_"
            + pipeline_class_name
            + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
        )
        filepath = os.path.join(BASE_PATH, name)
        return filepath
class TextToImageBenchmark(BaseBenchmak):
    pipeline_class = AutoPipelineForText2Image

    def __init__(self, args):
        pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
        pipe = pipe.to("cuda")

        if args.run_compile:
            if not isinstance(pipe, WuerstchenCombinedPipeline):
                pipe.unet.to(memory_format=torch.channels_last)
                print("Run torch compile")
                pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

                if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
                    pipe.movq.to(memory_format=torch.channels_last)
                    pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
            else:
                print("Run torch compile")
                pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
                pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)

        pipe.set_progress_bar_config(disable=True)
        self.pipe = pipe

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
        )

    def benchmark(self, args):
        flush()

        print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")

        time = benchmark_fn(self.run_inference, self.pipe, args)  # in seconds.
        memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
        benchmark_info = BenchmarkInfo(time=time, memory=memory)

        pipeline_class_name = str(self.pipe.__class__.__name__)
        flush()
        csv_dict = generate_csv_dict(
            pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
        )
        filepath = self.get_result_filepath(args)
        write_to_csv(filepath, csv_dict)
        print(f"Logs written to: {filepath}")
        flush()
class TurboTextToImageBenchmark(TextToImageBenchmark):
    def __init__(self, args):
        super().__init__(args)

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
            guidance_scale=0.0,
        )

class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
    lora_id = "latent-consistency/lcm-lora-sdxl"

    def __init__(self, args):
        super().__init__(args)
        self.pipe.load_lora_weights(self.lora_id)
        self.pipe.fuse_lora()
        self.pipe.unload_lora_weights()
        self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)

    def get_result_filepath(self, args):
        pipeline_class_name = str(self.pipe.__class__.__name__)
        name = (
            self.lora_id.replace("/", "_")
            + "_"
            + pipeline_class_name
            + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
        )
        filepath = os.path.join(BASE_PATH, name)
        return filepath

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
            guidance_scale=1.0,
        )

    def benchmark(self, args):
        flush()

        print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")

        time = benchmark_fn(self.run_inference, self.pipe, args)  # in seconds.
        memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())  # in GBs.
        benchmark_info = BenchmarkInfo(time=time, memory=memory)

        pipeline_class_name = str(self.pipe.__class__.__name__)
        flush()
        csv_dict = generate_csv_dict(
            pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
        )
        filepath = self.get_result_filepath(args)
        write_to_csv(filepath, csv_dict)
        print(f"Logs written to: {filepath}")
        flush()
class ImageToImageBenchmark(TextToImageBenchmark):
    pipeline_class = AutoPipelineForImage2Image
    url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
    image = load_image(url).convert("RGB")

    def __init__(self, args):
        super().__init__(args)
        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            image=self.image,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
        )

class TurboImageToImageBenchmark(ImageToImageBenchmark):
    def __init__(self, args):
        super().__init__(args)

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            image=self.image,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
            guidance_scale=0.0,
            strength=0.5,
        )

class InpaintingBenchmark(ImageToImageBenchmark):
    pipeline_class = AutoPipelineForInpainting
    mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
    mask = load_image(mask_url).convert("RGB")

    def __init__(self, args):
        super().__init__(args)
        self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
        self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])

    def run_inference(self, pipe, args):
        _ = pipe(
            prompt=PROMPT,
            image=self.image,
            mask_image=self.mask,
            num_inference_steps=args.num_inference_steps,
            num_images_per_prompt=args.batch_size,
        )
class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
image = load_image(url)
def __init__(self, args):
pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda")
pipe.load_ip_adapter(
args.ip_adapter_id[0],
subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models",
weight_name=args.ip_adapter_id[1],
)
if args.run_compile:
pipe.unet.to(memory_format=torch.channels_last)
print("Run torch compile")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.set_progress_bar_config(disable=True)
self.pipe = pipe
def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
ip_adapter_image=self.image,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
)
class ControlNetBenchmark(TextToImageBenchmark):
pipeline_class = StableDiffusionControlNetPipeline
aux_network_class = ControlNetModel
root_ckpt = "Lykon/DreamShaper"
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
image = load_image(url).convert("RGB")
def __init__(self, args):
aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.set_progress_bar_config(disable=True)
self.pipe = pipe
if args.run_compile:
pipe.unet.to(memory_format=torch.channels_last)
pipe.controlnet.to(memory_format=torch.channels_last)
print("Run torch compile")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
def run_inference(self, pipe, args):
_ = pipe(
prompt=PROMPT,
image=self.image,
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.batch_size,
)
class ControlNetSDXLBenchmark(ControlNetBenchmark):
pipeline_class = StableDiffusionXLControlNetPipeline
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
def __init__(self, args):
super().__init__(args)
class T2IAdapterBenchmark(ControlNetBenchmark):
pipeline_class = StableDiffusionAdapterPipeline
aux_network_class = T2IAdapter
root_ckpt = "Lykon/DreamShaper"
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
image = load_image(url).convert("L")
def __init__(self, args):
aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.set_progress_bar_config(disable=True)
self.pipe = pipe
if args.run_compile:
pipe.unet.to(memory_format=torch.channels_last)
pipe.adapter.to(memory_format=torch.channels_last)
print("Run torch compile")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
pipeline_class = StableDiffusionXLAdapterPipeline
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
image = load_image(url)
def __init__(self, args):
super().__init__(args)
-26
@@ -1,26 +0,0 @@
import argparse
import sys
sys.path.append(".")
from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark # noqa: E402
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="lllyasviel/sd-controlnet-canny",
choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()
benchmark_pipe = (
ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
)
benchmark_pipe.benchmark(args)
-33
@@ -1,33 +0,0 @@
import argparse
import sys
sys.path.append(".")
from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
IP_ADAPTER_CKPTS = {
# because original SD v1.5 has been taken down.
"Lykon/DreamShaper": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
"stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="stabilityai/stable-diffusion-xl-base-1.0",
choices=list(IP_ADAPTER_CKPTS.keys()),
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()
args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt]
benchmark_pipe = IPAdapterTextToImageBenchmark(args)
args.ckpt = f"{args.ckpt} (IP-Adapter)"
benchmark_pipe.benchmark(args)
-29
@@ -1,29 +0,0 @@
import argparse
import sys
sys.path.append(".")
from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark # noqa: E402
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="Lykon/DreamShaper",
choices=[
"Lykon/DreamShaper",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-xl-refiner-1.0",
"stabilityai/sdxl-turbo",
],
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()
benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args)
benchmark_pipe.benchmark(args)
-28
@@ -1,28 +0,0 @@
import argparse
import sys
sys.path.append(".")
from base_classes import InpaintingBenchmark # noqa: E402
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="Lykon/DreamShaper",
choices=[
"Lykon/DreamShaper",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-xl-base-1.0",
],
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()
benchmark_pipe = InpaintingBenchmark(args)
benchmark_pipe.benchmark(args)
-28
@@ -1,28 +0,0 @@
import argparse
import sys
sys.path.append(".")
from base_classes import T2IAdapterBenchmark, T2IAdapterSDXLBenchmark # noqa: E402
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="TencentARC/t2iadapter_canny_sd14v1",
choices=["TencentARC/t2iadapter_canny_sd14v1", "TencentARC/t2i-adapter-canny-sdxl-1.0"],
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()
benchmark_pipe = (
T2IAdapterBenchmark(args)
if args.ckpt == "TencentARC/t2iadapter_canny_sd14v1"
else T2IAdapterSDXLBenchmark(args)
)
benchmark_pipe.benchmark(args)
-23
@@ -1,23 +0,0 @@
import argparse
import sys
sys.path.append(".")
from base_classes import LCMLoRATextToImageBenchmark # noqa: E402
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="stabilityai/stable-diffusion-xl-base-1.0",
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=4)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()
benchmark_pipe = LCMLoRATextToImageBenchmark(args)
benchmark_pipe.benchmark(args)
-40
@@ -1,40 +0,0 @@
import argparse
import sys
sys.path.append(".")
from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa: E402
ALL_T2I_CKPTS = [
"Lykon/DreamShaper",
"segmind/SSD-1B",
"stabilityai/stable-diffusion-xl-base-1.0",
"kandinsky-community/kandinsky-2-2-decoder",
"warp-ai/wuerstchen",
"stabilityai/sdxl-turbo",
]
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--ckpt",
type=str,
default="Lykon/DreamShaper",
choices=ALL_T2I_CKPTS,
)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_inference_steps", type=int, default=50)
parser.add_argument("--model_cpu_offload", action="store_true")
parser.add_argument("--run_compile", action="store_true")
args = parser.parse_args()
benchmark_cls = None
if "turbo" in args.ckpt:
benchmark_cls = TurboTextToImageBenchmark
else:
benchmark_cls = TextToImageBenchmark
benchmark_pipe = benchmark_cls(args)
benchmark_pipe.benchmark(args)
+98
@@ -0,0 +1,98 @@
from functools import partial
import torch
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
from diffusers.utils.testing_utils import torch_device
CKPT_ID = "black-forest-labs/FLUX.1-dev"
RESULT_FILENAME = "flux.csv"
def get_input_dict(**device_dtype_kwargs):
# resolution: 1024x1024
# maximum sequence length 512
hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)
encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)
image_ids = torch.ones(4096, 3, **device_dtype_kwargs)
text_ids = torch.ones(512, 3, **device_dtype_kwargs)
timestep = torch.tensor([1.0], **device_dtype_kwargs)
guidance = torch.tensor([1.0], **device_dtype_kwargs)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
"guidance": guidance,
}
if __name__ == "__main__":
scenarios = [
BenchmarkScenario(
name=f"{CKPT_ID}-bf16",
model_cls=FluxTransformer2DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=model_init_fn,
compile_kwargs={"fullgraph": True},
),
BenchmarkScenario(
name=f"{CKPT_ID}-bnb-nf4",
model_cls=FluxTransformer2DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
"quantization_config": BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
),
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=model_init_fn,
),
BenchmarkScenario(
name=f"{CKPT_ID}-layerwise-upcasting",
model_cls=FluxTransformer2DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
),
BenchmarkScenario(
name=f"{CKPT_ID}-group-offload-leaf",
model_cls=FluxTransformer2DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=partial(
model_init_fn,
group_offload_kwargs={
"onload_device": torch_device,
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": True,
"non_blocking": True,
},
),
),
]
runner = BenchmarkMixin()
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
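The `hidden_states` shape in `get_input_dict` follows from the 1024x1024 resolution noted in its comment. A minimal sketch of that arithmetic, assuming an 8x VAE downscale, 16 latent channels, and 2x2 patch packing. These factors are assumptions and are not stated in this diff, and `flux_packed_latent_shape` is a hypothetical helper, not part of the benchmark code:

```python
def flux_packed_latent_shape(height, width, vae_scale_factor=8, latent_channels=16, patch_size=2):
    """Return (num_tokens, token_dim) for a packed latent (assumed factors)."""
    # Spatial size of the latent after the (assumed) 8x VAE downscale.
    latent_h = height // vae_scale_factor
    latent_w = width // vae_scale_factor
    # Each 2x2 latent patch becomes one token.
    num_tokens = (latent_h // patch_size) * (latent_w // patch_size)
    # Channels of a patch are flattened into the token dimension.
    token_dim = latent_channels * patch_size * patch_size
    return num_tokens, token_dim


flux_shape = flux_packed_latent_shape(1024, 1024)
print(flux_shape)  # (4096, 64), matching torch.randn(1, 4096, 64, ...)
```

Under these assumptions the image side has 4096 tokens, while the text side has 512 (the maximum sequence length in the comment).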
+80
@@ -0,0 +1,80 @@
from functools import partial
import torch
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
from diffusers import LTXVideoTransformer3DModel
from diffusers.utils.testing_utils import torch_device
CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev"
RESULT_FILENAME = "ltx.csv"
def get_input_dict(**device_dtype_kwargs):
# 512x704 (161 frames)
# `max_sequence_length`: 256
hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs)
encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs)
encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs)
timestep = torch.tensor([1.0], **device_dtype_kwargs)
video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"timestep": timestep,
"video_coords": video_coords,
}
if __name__ == "__main__":
scenarios = [
BenchmarkScenario(
name=f"{CKPT_ID}-bf16",
model_cls=LTXVideoTransformer3DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=model_init_fn,
compile_kwargs={"fullgraph": True},
),
BenchmarkScenario(
name=f"{CKPT_ID}-layerwise-upcasting",
model_cls=LTXVideoTransformer3DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
),
BenchmarkScenario(
name=f"{CKPT_ID}-group-offload-leaf",
model_cls=LTXVideoTransformer3DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=partial(
model_init_fn,
group_offload_kwargs={
"onload_device": torch_device,
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": True,
"non_blocking": True,
},
),
),
]
runner = BenchmarkMixin()
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
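The token count of 7392 used for `hidden_states` and `video_coords` can be reproduced from the "512x704 (161 frames)" comment. A minimal sketch, assuming a 32x spatial and 8x temporal compression in the video VAE. Both factors are assumptions not stated in this diff, and `ltx_num_tokens` is a hypothetical helper:

```python
def ltx_num_tokens(height, width, num_frames, spatial_compression=32, temporal_compression=8):
    """Number of latent tokens for a video under the assumed compression factors."""
    # The first frame is kept, the remaining frames are compressed in time.
    latent_frames = (num_frames - 1) // temporal_compression + 1
    latent_h = height // spatial_compression
    latent_w = width // spatial_compression
    # Tokens are the flattened (frames, height, width) latent grid.
    return latent_frames * latent_h * latent_w


ltx_tokens = ltx_num_tokens(512, 704, 161)
print(ltx_tokens)  # 7392 = 21 * 16 * 22
```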
+82
@@ -0,0 +1,82 @@
from functools import partial
import torch
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
from diffusers import UNet2DConditionModel
from diffusers.utils.testing_utils import torch_device
CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"
RESULT_FILENAME = "sdxl.csv"
def get_input_dict(**device_dtype_kwargs):
# height: 1024
# width: 1024
# max_sequence_length: 77
hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs)
encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs)
timestep = torch.tensor([1.0], **device_dtype_kwargs)
added_cond_kwargs = {
"text_embeds": torch.randn(1, 1280, **device_dtype_kwargs),
"time_ids": torch.ones(1, 6, **device_dtype_kwargs),
}
return {
"sample": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"added_cond_kwargs": added_cond_kwargs,
}
if __name__ == "__main__":
scenarios = [
BenchmarkScenario(
name=f"{CKPT_ID}-bf16",
model_cls=UNet2DConditionModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "unet",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=model_init_fn,
compile_kwargs={"fullgraph": True},
),
BenchmarkScenario(
name=f"{CKPT_ID}-layerwise-upcasting",
model_cls=UNet2DConditionModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "unet",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
),
BenchmarkScenario(
name=f"{CKPT_ID}-group-offload-leaf",
model_cls=UNet2DConditionModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "unet",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=partial(
model_init_fn,
group_offload_kwargs={
"onload_device": torch_device,
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": True,
"non_blocking": True,
},
),
),
]
runner = BenchmarkMixin()
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
+244
@@ -0,0 +1,244 @@
import gc
import inspect
import logging
import os
import queue
import threading
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union
import pandas as pd
import torch
import torch.utils.benchmark as benchmark
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
NUM_WARMUP_ROUNDS = 5
def benchmark_fn(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=1,
)
return float(f"{(t0.blocked_autorange().mean):.3f}")
def flush():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
def calculate_flops(model, input_dict):
try:
from torchprofile import profile_macs
except ModuleNotFoundError:
raise
# This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
sig = inspect.signature(model.forward)
param_names = [
p.name
for p in sig.parameters.values()
if p.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
and p.name != "self"
]
bound = sig.bind_partial(**input_dict)
bound.apply_defaults()
args = tuple(bound.arguments[name] for name in param_names)
model.eval()
with torch.no_grad():
macs = profile_macs(model, args)
flops = 2 * macs # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
return flops
def calculate_params(model):
return sum(p.numel() for p in model.parameters())
# Users can define their own in case this doesn't suffice. For most cases,
# it should be sufficient.
def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
model = model_cls.from_pretrained(**init_kwargs).eval()
if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
model.enable_group_offload(**group_offload_kwargs)
else:
model.to(torch_device)
if layerwise_upcasting:
model.enable_layerwise_casting(
storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16)
)
return model
@dataclass
class BenchmarkScenario:
name: str
model_cls: ModelMixin
model_init_kwargs: Dict[str, Any]
model_init_fn: Callable
get_model_input_dict: Callable
compile_kwargs: Optional[Dict[str, Any]] = None
@require_torch_gpu
class BenchmarkMixin:
def pre_benchmark(self):
flush()
torch.compiler.reset()
def post_benchmark(self, model):
model.cpu()
flush()
torch.compiler.reset()
@torch.no_grad()
def run_benchmark(self, scenario: BenchmarkScenario):
# 0) Basic stats
logger.info(f"Running scenario: {scenario.name}.")
try:
model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
num_params = round(calculate_params(model) / 1e9, 2)
try:
flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
except Exception as e:
logger.info(f"Problem in calculating FLOPs:\n{e}")
flops = None
model.cpu()
del model
except Exception as e:
logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
return {}
self.pre_benchmark()
# 1) plain stats
results = {}
plain = None
try:
plain = self._run_phase(
model_cls=scenario.model_cls,
init_fn=scenario.model_init_fn,
init_kwargs=scenario.model_init_kwargs,
get_input_fn=scenario.get_model_input_dict,
compile_kwargs=None,
)
except Exception as e:
logger.info(f"Benchmark could not be run with the following error:\n{e}")
return results
# 2) compiled stats (if any)
compiled = {"time": None, "memory": None}
if scenario.compile_kwargs:
try:
compiled = self._run_phase(
model_cls=scenario.model_cls,
init_fn=scenario.model_init_fn,
init_kwargs=scenario.model_init_kwargs,
get_input_fn=scenario.get_model_input_dict,
compile_kwargs=scenario.compile_kwargs,
)
except Exception as e:
logger.info(f"Compilation benchmark could not be run with the following error:\n{e}")
if plain is None:
return results
# 3) merge
result = {
"scenario": scenario.name,
"model_cls": scenario.model_cls.__name__,
"num_params_B": num_params,
"flops_G": flops,
"time_plain_s": plain["time"],
"mem_plain_GB": plain["memory"],
"time_compile_s": compiled["time"],
"mem_compile_GB": compiled["memory"],
}
if scenario.compile_kwargs:
result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
result["mode"] = scenario.compile_kwargs.get("mode", "default")
else:
result["fullgraph"], result["mode"] = None, None
return result
def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
if not isinstance(scenarios, list):
scenarios = [scenarios]
record_queue = queue.Queue()
stop_signal = object()
def _writer_thread():
while True:
item = record_queue.get()
if item is stop_signal:
break
df_row = pd.DataFrame([item])
write_header = not os.path.exists(filename)
df_row.to_csv(filename, mode="a", header=write_header, index=False)
record_queue.task_done()
record_queue.task_done()
writer = threading.Thread(target=_writer_thread, daemon=True)
writer.start()
for s in scenarios:
try:
record = self.run_benchmark(s)
if record:
record_queue.put(record)
else:
logger.info(f"Record empty from scenario: {s.name}.")
except Exception as e:
logger.info(f"Running scenario ({s.name}) led to error:\n{e}")
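record_queue.put(stop_signal)
writer.join()  # the writer is a daemon thread; wait for it so the last rows are flushed before returning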
logger.info(f"Results serialized to {filename=}.")
def _run_phase(
self,
*,
model_cls: ModelMixin,
init_fn: Callable,
init_kwargs: Dict[str, Any],
get_input_fn: Callable,
compile_kwargs: Optional[Dict[str, Any]],
) -> Dict[str, float]:
# setup
self.pre_benchmark()
# init & (optional) compile
model = init_fn(model_cls, **init_kwargs)
if compile_kwargs:
model.compile(**compile_kwargs)
# build inputs
inp = get_input_fn()
# measure
run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
with run_ctx:
for _ in range(NUM_WARMUP_ROUNDS):
_ = model(**inp)
time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
mem_gb = round(mem_gb, 2)
# teardown
self.post_benchmark(model)
del model
return {"time": time_s, "memory": mem_gb}
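The kwargs-to-args conversion in `calculate_flops` relies only on the standard library, so it can be tried in isolation. A minimal sketch of the same idea, where `kwargs_to_positional` and the toy `forward` function are hypothetical names used for illustration and do not appear in the benchmark code:

```python
import inspect


def kwargs_to_positional(fn, kwargs):
    """Order a kwargs dict into a positional tuple that matches fn's signature."""
    sig = inspect.signature(fn)
    # Keep only parameters that can be passed positionally.
    names = [
        p.name
        for p in sig.parameters.values()
        if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    ]
    # bind_partial tolerates missing arguments; apply_defaults fills in declared defaults.
    bound = sig.bind_partial(**kwargs)
    bound.apply_defaults()
    # Raises KeyError if a parameter without a default was not supplied.
    return tuple(bound.arguments[name] for name in names)


def forward(hidden_states, timestep=None, scale=1.0):
    return hidden_states, timestep, scale


positional = kwargs_to_positional(forward, {"scale": 2.0, "hidden_states": "x"})
print(positional)  # ('x', None, 2.0)
```

The tuple follows the declaration order of the signature rather than the order of the dict, which is what a tracer that accepts only positional inputs needs.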
+74
@@ -0,0 +1,74 @@
from functools import partial
import torch
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import torch_device
CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
RESULT_FILENAME = "wan.csv"
def get_input_dict(**device_dtype_kwargs):
# height: 480
# width: 832
# num_frames: 81
# max_sequence_length: 512
hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
timestep = torch.tensor([1.0], **device_dtype_kwargs)
return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}
if __name__ == "__main__":
scenarios = [
BenchmarkScenario(
name=f"{CKPT_ID}-bf16",
model_cls=WanTransformer3DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=model_init_fn,
compile_kwargs={"fullgraph": True},
),
BenchmarkScenario(
name=f"{CKPT_ID}-layerwise-upcasting",
model_cls=WanTransformer3DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
),
BenchmarkScenario(
name=f"{CKPT_ID}-group-offload-leaf",
model_cls=WanTransformer3DModel,
model_init_kwargs={
"pretrained_model_name_or_path": CKPT_ID,
"torch_dtype": torch.bfloat16,
"subfolder": "transformer",
},
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
model_init_fn=partial(
model_init_fn,
group_offload_kwargs={
"onload_device": torch_device,
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
"use_stream": True,
"non_blocking": True,
},
),
),
]
runner = BenchmarkMixin()
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
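The 5D `hidden_states` shape `(1, 16, 21, 60, 104)` can be derived from the height, width, and frame count listed in the comments of `get_input_dict`. A minimal sketch, assuming 16 latent channels, an 8x spatial and a 4x temporal VAE compression. These factors are assumptions not stated in this diff, and `wan_latent_shape` is a hypothetical helper:

```python
def wan_latent_shape(height, width, num_frames, batch_size=1, latent_channels=16,
                     spatial_compression=8, temporal_compression=4):
    """Latent shape (B, C, F, H, W) under the assumed compression factors."""
    # The first frame is kept, the remaining frames are compressed in time.
    latent_frames = (num_frames - 1) // temporal_compression + 1
    return (
        batch_size,
        latent_channels,
        latent_frames,
        height // spatial_compression,
        width // spatial_compression,
    )


wan_shape = wan_latent_shape(480, 832, 81)
print(wan_shape)  # (1, 16, 21, 60, 104)
```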
+166
@@ -0,0 +1,166 @@
import argparse
import os
import sys
import gpustat
import pandas as pd
import psycopg2
import psycopg2.extras
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json
register_adapter(dict, Json)
FINAL_CSV_FILENAME = "collated_results.csv"
# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
BENCHMARKS_TABLE_NAME = "benchmarks"
MEASUREMENTS_TABLE_NAME = "model_measurements"
def _init_benchmark(conn, branch, commit_id, commit_msg):
gpu_stats = gpustat.GPUStatCollection.new_query()
metadata = {"gpu_name": gpu_stats[0]["name"]}
repository = "huggingface/diffusers"
with conn.cursor() as cur:
cur.execute(
f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
(repository, branch, commit_id, commit_msg, metadata),
)
benchmark_id = cur.fetchone()[0]
print(f"Initialised benchmark #{benchmark_id}")
return benchmark_id
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"branch",
type=str,
help="The branch name on which the benchmarking is performed.",
)
parser.add_argument(
"commit_id",
type=str,
help="The commit hash on which the benchmarking is performed.",
)
parser.add_argument(
"commit_msg",
type=str,
help="The commit message associated with the commit, truncated to 70 characters.",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
try:
conn = psycopg2.connect(
host=os.getenv("PGHOST"),
database=os.getenv("PGDATABASE"),
user=os.getenv("PGUSER"),
password=os.getenv("PGPASSWORD"),
)
print("DB connection established successfully.")
except Exception as e:
print(f"Problem during DB init: {e}")
sys.exit(1)
try:
benchmark_id = _init_benchmark(
conn=conn,
branch=args.branch,
commit_id=args.commit_id,
commit_msg=args.commit_msg,
)
except Exception as e:
print(f"Problem during initializing benchmark: {e}")
sys.exit(1)
cur = conn.cursor()
df = pd.read_csv(FINAL_CSV_FILENAME)
# Helper to cast values (or None) given a dtype
def _cast_value(val, dtype: str):
if pd.isna(val):
return None
if dtype == "text":
return str(val).strip()
if dtype == "float":
try:
return float(val)
except ValueError:
return None
if dtype == "bool":
s = str(val).strip().lower()
if s in ("true", "t", "yes", "1"):
return True
if s in ("false", "f", "no", "0"):
return False
if val in (1, 1.0):
return True
if val in (0, 0.0):
return False
return None
return val
try:
rows_to_insert = []
for _, row in df.iterrows():
scenario = _cast_value(row.get("scenario"), "text")
model_cls = _cast_value(row.get("model_cls"), "text")
num_params_B = _cast_value(row.get("num_params_B"), "float")
flops_G = _cast_value(row.get("flops_G"), "float")
time_plain_s = _cast_value(row.get("time_plain_s"), "float")
mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
time_compile_s = _cast_value(row.get("time_compile_s"), "float")
mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
fullgraph = _cast_value(row.get("fullgraph"), "bool")
mode = _cast_value(row.get("mode"), "text")
# If "github_sha" column exists in the CSV, cast it; else default to None
if "github_sha" in df.columns:
github_sha = _cast_value(row.get("github_sha"), "text")
else:
github_sha = None
measurements = {
"scenario": scenario,
"model_cls": model_cls,
"num_params_B": num_params_B,
"flops_G": flops_G,
"time_plain_s": time_plain_s,
"mem_plain_GB": mem_plain_GB,
"time_compile_s": time_compile_s,
"mem_compile_GB": mem_compile_GB,
"fullgraph": fullgraph,
"mode": mode,
"github_sha": github_sha,
}
rows_to_insert.append((benchmark_id, measurements))
# Batch-insert all rows
insert_sql = f"""
INSERT INTO {MEASUREMENTS_TABLE_NAME} (
benchmark_id,
measurements
)
VALUES (%s, %s);
"""
psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
conn.commit()
cur.close()
conn.close()
except Exception as e:
print(f"Exception: {e}")
sys.exit(1)
+30 -26
@@ -1,19 +1,19 @@
import glob
import sys
import os
import pandas as pd
from huggingface_hub import hf_hub_download, upload_file
from huggingface_hub.utils import EntryNotFoundError
sys.path.append(".")
from utils import BASE_PATH, FINAL_CSV_FILE, GITHUB_SHA, REPO_ID, collate_csv # noqa: E402
REPO_ID = "diffusers/benchmarks"
def has_previous_benchmark() -> str:
from run_all import FINAL_CSV_FILENAME
csv_path = None
try:
csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILE)
csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILENAME)
except EntryNotFoundError:
csv_path = None
return csv_path
@@ -26,46 +26,50 @@ def filter_float(value):
def push_to_hf_dataset():
all_csvs = sorted(glob.glob(f"{BASE_PATH}/*.csv"))
collate_csv(all_csvs, FINAL_CSV_FILE)
from run_all import FINAL_CSV_FILENAME, GITHUB_SHA
# If there's an existing benchmark file, we should report the changes.
csv_path = has_previous_benchmark()
if csv_path is not None:
current_results = pd.read_csv(FINAL_CSV_FILE)
current_results = pd.read_csv(FINAL_CSV_FILENAME)
previous_results = pd.read_csv(csv_path)
numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
numeric_columns = [
c for c in numeric_columns if c not in ["batch_size", "num_inference_steps", "actual_gpu_memory (gbs)"]
]
for column in numeric_columns:
previous_results[column] = previous_results[column].map(lambda x: filter_float(x))
# get previous values as floats, aligned to current index
prev_vals = previous_results[column].map(filter_float).reindex(current_results.index)
# Calculate the percentage change
current_results[column] = current_results[column].astype(float)
previous_results[column] = previous_results[column].astype(float)
percent_change = ((current_results[column] - previous_results[column]) / previous_results[column]) * 100
# get current values as floats
curr_vals = current_results[column].astype(float)
# Format the values with '+' or '-' sign and append to original values
current_results[column] = current_results[column].map(str) + percent_change.map(
lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)"
# stringify the current values
curr_str = curr_vals.map(str)
# build an appendage only when prev exists and differs
append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(
lambda x: f" ({x})" if pd.notnull(x) else ""
)
# There might be newly added rows. So, filter out the NaNs.
current_results[column] = current_results[column].map(lambda x: x.replace(" (nan%)", ""))
# Overwrite the current result file.
current_results.to_csv(FINAL_CSV_FILE, index=False)
# combine
current_results[column] = curr_str + append_str
os.remove(FINAL_CSV_FILENAME)
current_results.to_csv(FINAL_CSV_FILENAME, index=False)
commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
upload_file(
repo_id=REPO_ID,
path_in_repo=FINAL_CSV_FILE,
path_or_fileobj=FINAL_CSV_FILE,
path_in_repo=FINAL_CSV_FILENAME,
path_or_fileobj=FINAL_CSV_FILENAME,
repo_type="dataset",
commit_message=commit_message,
)
upload_file(
repo_id="diffusers/benchmark-analyzer",
path_in_repo=FINAL_CSV_FILENAME,
path_or_fileobj=FINAL_CSV_FILENAME,
repo_type="space",
commit_message=commit_message,
)
if __name__ == "__main__":
+6
@@ -0,0 +1,6 @@
pandas
psutil
gpustat
torchprofile
bitsandbytes
psycopg2==2.9.9
+55 -72
@@ -1,101 +1,84 @@
import glob
import logging
import os
import subprocess
import sys
from typing import List
import pandas as pd
sys.path.append(".")
from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
PATTERN = "benchmark_*.py"
PATTERN = "benchmarking_*.py"
FINAL_CSV_FILENAME = "collated_results.csv"
GITHUB_SHA = os.getenv("GITHUB_SHA", None)
class SubprocessCallException(Exception):
pass
# Taken from `test_examples_utils.py`
def run_command(command: List[str], return_stdout=False):
"""
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
if an error occurred while running `command`
"""
def run_command(command: list[str], return_stdout=False):
try:
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
if return_stdout:
if hasattr(output, "decode"):
output = output.decode("utf-8")
return output
if return_stdout and hasattr(output, "decode"):
return output.decode("utf-8")
except subprocess.CalledProcessError as e:
raise SubprocessCallException(
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
) from e
raise SubprocessCallException(f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}") from e
def main():
python_files = glob.glob(PATTERN)
def merge_csvs(final_csv: str = "collated_results.csv"):
all_csvs = glob.glob("*.csv")
all_csvs = [f for f in all_csvs if f != final_csv]
if not all_csvs:
logger.info("No result CSVs found to merge.")
return
for file in python_files:
print(f"****** Running file: {file} ******")
# Run with canonical settings.
if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py":
command = f"python {file}"
run_command(command.split())
command += " --run_compile"
run_command(command.split())
# Run variants.
for file in python_files:
# See: https://github.com/pytorch/pytorch/issues/129637
if file == "benchmark_ip_adapters.py":
df_list = []
for f in all_csvs:
try:
d = pd.read_csv(f)
except pd.errors.EmptyDataError:
# If a file exists but is zero bytes or corrupted, skip it
continue
df_list.append(d)
if file == "benchmark_text_to_image.py":
for ckpt in ALL_T2I_CKPTS:
command = f"python {file} --ckpt {ckpt}"
if not df_list:
logger.info("All result CSVs were empty or invalid; nothing to merge.")
return
if "turbo" in ckpt:
command += " --num_inference_steps 1"
final_df = pd.concat(df_list, ignore_index=True)
if GITHUB_SHA is not None:
final_df["github_sha"] = GITHUB_SHA
final_df.to_csv(final_csv, index=False)
logger.info(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.")
run_command(command.split())
command += " --run_compile"
run_command(command.split())
def run_scripts():
python_files = sorted(glob.glob(PATTERN))
python_files = [f for f in python_files if f != "benchmarking_utils.py"]
elif file == "benchmark_sd_img.py":
for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
command = f"python {file} --ckpt {ckpt}"
for file in python_files:
script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo
logger.info(f"\n****** Running file: {file} ******")
if ckpt == "stabilityai/sdxl-turbo":
command += " --num_inference_steps 2"
partial_csv = f"{script_name}.csv"
if os.path.exists(partial_csv):
logger.info(f"Found {partial_csv}. Removing for safer numbers and duplication.")
os.remove(partial_csv)
run_command(command.split())
command += " --run_compile"
run_command(command.split())
command = ["python", file]
try:
run_command(command)
logger.info(f"{file} finished normally.")
except SubprocessCallException as e:
logger.info(f"Error running {file}:\n{e}")
finally:
logger.info(f"→ Merging partial CSVs after {file}")
merge_csvs(final_csv=FINAL_CSV_FILENAME)
elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
command = f"python {file} --ckpt {sdxl_ckpt}"
run_command(command.split())
command += " --run_compile"
run_command(command.split())
elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]:
sdxl_ckpt = (
"diffusers/controlnet-canny-sdxl-1.0"
if "controlnet" in file
else "TencentARC/t2i-adapter-canny-sdxl-1.0"
)
command = f"python {file} --ckpt {sdxl_ckpt}"
run_command(command.split())
command += " --run_compile"
run_command(command.split())
logger.info(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}")
if __name__ == "__main__":
main()
run_scripts()
-98
@@ -1,98 +0,0 @@
import argparse
import csv
import gc
import os
from dataclasses import dataclass
from typing import Dict, List, Union
import torch
import torch.utils.benchmark as benchmark
GITHUB_SHA = os.getenv("GITHUB_SHA", None)
BENCHMARK_FIELDS = [
"pipeline_cls",
"ckpt_id",
"batch_size",
"num_inference_steps",
"model_cpu_offload",
"run_compile",
"time (secs)",
"memory (gbs)",
"actual_gpu_memory (gbs)",
"github_sha",
]
PROMPT = "ghibli style, a fantasy landscape with castles"
BASE_PATH = os.getenv("BASE_PATH", ".")
TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)))
REPO_ID = "diffusers/benchmarks"
FINAL_CSV_FILE = "collated_results.csv"
@dataclass
class BenchmarkInfo:
time: float
memory: float
def flush():
"""Wipes off memory."""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
def bytes_to_giga_bytes(bytes):
return f"{(bytes / 1024 / 1024 / 1024):.3f}"
def benchmark_fn(f, *args, **kwargs):
t0 = benchmark.Timer(
stmt="f(*args, **kwargs)",
globals={"args": args, "kwargs": kwargs, "f": f},
num_threads=torch.get_num_threads(),
)
return f"{(t0.blocked_autorange().mean):.3f}"
def generate_csv_dict(
pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo
) -> Dict[str, Union[str, bool, float]]:
"""Packs benchmarking data into a dictionary for later serialization."""
data_dict = {
"pipeline_cls": pipeline_cls,
"ckpt_id": ckpt,
"batch_size": args.batch_size,
"num_inference_steps": args.num_inference_steps,
"model_cpu_offload": args.model_cpu_offload,
"run_compile": args.run_compile,
"time (secs)": benchmark_info.time,
"memory (gbs)": benchmark_info.memory,
"actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
"github_sha": GITHUB_SHA,
}
return data_dict
def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]):
"""Serializes a dictionary into a CSV file."""
with open(file_name, mode="w", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS)
writer.writeheader()
writer.writerow(data_dict)
def collate_csv(input_files: List[str], output_file: str):
"""Collates multiple identically structured CSVs into a single CSV file."""
with open(output_file, mode="w", newline="") as outfile:
writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS)
writer.writeheader()
for file in input_files:
with open(file, mode="r") as infile:
reader = csv.DictReader(infile)
for row in reader:
writer.writerow(row)
+2
@@ -64,6 +64,8 @@
title: Overview
- local: using-diffusers/create_a_server
title: Create a server
- local: using-diffusers/batched_inference
title: Batch inference
- local: training/distributed_inference
title: Distributed inference
- local: using-diffusers/scheduler_features
+6
@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
+3
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# aMUSEd
aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen.
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Attend-and-Excite
Attend-and-Excite for Stable Diffusion was proposed in [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://attendandexcite.github.io/Attend-and-Excite/) and provides textual attention control over image generation.
+3
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# AudioLDM
AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://huggingface.co/papers/2301.12503) by Haohe Liu et al. Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# BLIP-Diffusion
BLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://huggingface.co/papers/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# ControlNet-XS
<div class="flex flex-wrap space-x-1">
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
+25
@@ -24,6 +24,31 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>
## Loading original format checkpoints
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
```python
import torch
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
transformer = CosmosTransformer3DModel.from_single_file(
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
output = pipe(
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
).images[0]
output.save("output.png")
```
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Dance Diffusion
[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) is by Zach Evans.
+3
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# DiffEdit
[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://huggingface.co/papers/2210.11427) is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord.
+3
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# I2VGen-XL
[I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models](https://hf.co/papers/2311.04145.pdf) by Shiwei Zhang, Jiayu Wang, Yingya Zhang, Kang Zhao, Hangjie Yuan, Zhiwu Qin, Xiang Wang, Deli Zhao, and Jingren Zhou.
+3
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# MusicLDM
MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Paint by Example
[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://huggingface.co/papers/2211.13227) is by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen.
+3
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# MultiDiffusion
<div class="flex flex-wrap space-x-1">
+3
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Image-to-Video Generation with PIA (Personalized Image Animator)
<div class="flex flex-wrap space-x-1">
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Self-Attention Guidance
[Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) is by Susung Hong et al.
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Semantic Guidance
Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Text-to-Image Models using Semantic Guidance](https://huggingface.co/papers/2301.12247) and provides strong semantic control over image generation.
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# GLIGEN (Grounded Language-to-Image Generation)
The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# K-Diffusion
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable Diffusion with samplers from k-diffusion.
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Text-to-(RGB, depth)
<div class="flex flex-wrap space-x-1">
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Safe Stable Diffusion
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105) and mitigates inappropriate degeneration from Stable Diffusion models because they're trained on unfiltered web-crawled datasets. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content.
@@ -10,11 +10,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
<Tip warning={true}>
🧪 This pipeline is for research purposes only.
</Tip>
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Text-to-video
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Text2Video-Zero
<div class="flex flex-wrap space-x-1">
+3
@@ -7,6 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# unCLIP
[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# UniDiffuser
<div class="flex flex-wrap space-x-1">
+3 -3
@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
```py
# pip install ftfy
import torch
from diffusers import WanPipeline, AutoModel
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
vae = AutoModel.from_single_file(
vae = AutoencoderKLWan.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
transformer = AutoModel.from_single_file(
transformer = WanTransformer3DModel.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
torch_dtype=torch.bfloat16
)
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
# Würstchen
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
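One rough way to decide which LoRA to load first is to count the distinct layers each adapter targets. The sketch below is only an illustration: it assumes PEFT-style state dict keys ending in `.lora_A.weight` or `.lora_B.weight` (other checkpoint formats name their keys differently), and `targeted_layers` and `pick_first_adapter` are hypothetical helpers, not part of Diffusers or PEFT. The adapter names and keys are made up for the example.

```py
# Assumed PEFT-style key suffixes; adjust for other LoRA checkpoint formats.
LORA_SUFFIXES = (".lora_A.weight", ".lora_B.weight")


def targeted_layers(state_dict_keys):
    """Return the set of layer names that have LoRA weights attached."""
    layers = set()
    for key in state_dict_keys:
        for suffix in LORA_SUFFIXES:
            if key.endswith(suffix):
                layers.add(key[: -len(suffix)])
    return layers


def pick_first_adapter(adapters):
    """Given {adapter_name: state_dict_keys}, return the adapter targeting the most layers."""
    return max(adapters, key=lambda name: len(targeted_layers(adapters[name])))


# Made-up adapters: "detail" targets two layers, "style" targets one.
adapters = {
    "style": [
        "unet.attn1.to_q.lora_A.weight",
        "unet.attn1.to_q.lora_B.weight",
    ],
    "detail": [
        "unet.attn1.to_q.lora_A.weight",
        "unet.attn1.to_q.lora_B.weight",
        "unet.attn1.to_k.lora_A.weight",
        "unet.attn1.to_k.lora_B.weight",
    ],
}
first = pick_first_adapter(adapters)
```

The layer count is only a proxy. What matters for avoiding recompilation is that every later adapter targets a subset of the layers covered by the first one, so it is worth comparing the sets returned by `targeted_layers` directly when the counts are close.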
## Merge
@@ -0,0 +1,264 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Batch inference
Batch inference processes multiple prompts at a time to increase throughput. It is more efficient because processing several prompts at once keeps the GPU busy, whereas processing a single prompt at a time often leaves it underutilized.
The downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.
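As a rough illustration of this trade-off, a long list of prompts can be split into fixed-size batches so that each pipeline call stays within a memory budget. The `chunk_prompts` helper and the batch size used here are illustrative assumptions, not part of Diffusers.

```py
def chunk_prompts(prompts, batch_size):
    """Split a list of prompts into consecutive batches of at most `batch_size`."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    return [prompts[i : i + batch_size] for i in range(0, len(prompts), batch_size)]


prompts = [f"prompt {i}" for i in range(5)]
batches = chunk_prompts(prompts, batch_size=2)
# Each batch would then be passed to the pipeline in turn,
# for example: pipeline(prompt=batch).images
```

A larger `batch_size` generally improves throughput until GPU memory runs out, while a smaller one returns the first results sooner.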
<hfoptions id="usage">
<hfoption id="text-to-image">
For text-to-image, pass a list of prompts to the pipeline.
```py
import torch
import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
prompts = [
"cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
"pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
]
images = pipeline(
prompt=prompts,
).images
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()
for i, image in enumerate(images):
axes[i].imshow(image)
axes[i].set_title(f"Image {i+1}")
axes[i].axis('off')
plt.tight_layout()
plt.show()
```
To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
```py
import torch
import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
images = pipeline(
prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
num_images_per_prompt=4
).images
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()
for i, image in enumerate(images):
axes[i].imshow(image)
axes[i].set_title(f"Image {i+1}")
axes[i].axis('off')
plt.tight_layout()
plt.show()
```
Combine both approaches to generate different variations of different prompts.
```py
images = pipeline(
prompt=prompts,
num_images_per_prompt=2,
).images
# 3 prompts x 2 images per prompt = 6 images, so use a 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()
for i, image in enumerate(images):
    axes[i].imshow(image)
    axes[i].set_title(f"Image {i+1}")
    axes[i].axis('off')
plt.tight_layout()
plt.show()
```
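When prompts and `num_images_per_prompt` are combined, the pipeline returns one flat list. A minimal sketch for splitting it back into one group per prompt follows. It assumes the outputs are ordered prompt by prompt (all images for the first prompt, then all for the second, and so on), so check that against your pipeline. `group_by_prompt` is a hypothetical helper, and strings stand in for images so the snippet runs without a GPU.

```py
def group_by_prompt(images, prompts, num_images_per_prompt):
    """Split a flat, prompt-major list of outputs into one sub-list per prompt."""
    expected = len(prompts) * num_images_per_prompt
    if len(images) != expected:
        raise ValueError(f"expected {expected} images, got {len(images)}")
    return [
        images[i * num_images_per_prompt:(i + 1) * num_images_per_prompt]
        for i in range(len(prompts))
    ]


# Stand-in strings instead of PIL images.
fake_images = ["sunset-a", "sunset-b", "cat-a", "cat-b", "cafe-a", "cafe-b"]
groups = group_by_prompt(fake_images, ["sunset", "cat", "cafe"], 2)
print(groups)
```

The grouped layout also suggests a grid shape for plotting: one row per prompt and one column per variation.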
</hfoption>
<hfoption id="image-to-image">
For image-to-image, pass a list of input images and prompts to the pipeline.
```py
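import torch
import matplotlib.pyplot as plt
from diffusers.utils import load_image
from diffusers import AutoPipelineForImage2Image
# the image-to-image pipeline class accepts the `image` and `strength` arguments
pipeline = AutoPipelineForImage2Image.from_pretrained(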
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
input_images = [
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"),
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
]
prompts = [
"cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
"pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
]
images = pipeline(
prompt=prompts,
image=input_images,
guidance_scale=8.0,
strength=0.5
).images
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()
for i, image in enumerate(images):
    axes[i].imshow(image)
    axes[i].set_title(f"Image {i+1}")
    axes[i].axis('off')
plt.tight_layout()
plt.show()
```
To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
```py
import torch
import matplotlib.pyplot as plt
from diffusers.utils import load_image
from diffusers import AutoPipelineForImage2Image
# the image-to-image pipeline class accepts the `image` argument
pipeline = AutoPipelineForImage2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
images = pipeline(
prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
image=input_image,
num_images_per_prompt=4
).images
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()
for i, image in enumerate(images):
    axes[i].imshow(image)
    axes[i].set_title(f"Image {i+1}")
    axes[i].axis('off')
plt.tight_layout()
plt.show()
```
Combine both approaches to generate different variations of different prompts.
```py
input_images = [
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
]
prompts = [
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
"pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
]
images = pipeline(
prompt=prompts,
image=input_images,
num_images_per_prompt=2,
).images
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()
for i, image in enumerate(images):
    axes[i].imshow(image)
    axes[i].set_title(f"Image {i+1}")
    axes[i].axis('off')
plt.tight_layout()
plt.show()
```
</hfoption>
</hfoptions>
## Deterministic generation
Enable reproducible batch generation by passing a list of [Generators](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tying each `Generator` to a seed so you can reuse it.
Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch.
Don't multiply a list containing one `Generator` by the batch size, as in the snippet below, because that creates only one `Generator` object that is used sequentially for each image in the batch.
```py
generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
```
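The aliasing is easy to see with plain Python objects. The sketch below uses `random.Random` from the standard library as a stand-in for `torch.Generator` (an assumption made so the example runs without PyTorch or a GPU); the list semantics are the same for any object.

```py
import random

# Multiplying a list repeats the *reference*: three slots, one RNG object.
shared = [random.Random(0)] * 3
# A list comprehension builds a separate RNG object for every slot.
independent = [random.Random(seed) for seed in range(3)]

print(shared[0] is shared[1])            # True
print(independent[0] is independent[1])  # False

# Drawing through the shared list advances the same RNG three times in a row.
shared_draws = [rng.random() for rng in shared]
# Three separate RNGs with the same seed each start from the same state.
fresh_draws = [random.Random(0).random() for _ in range(3)]
print(len(set(shared_draws)), len(set(fresh_draws)))
```

With the shared object, the second and third images depend on how much randomness the first one consumed, so a single image cannot be reproduced from its seed alone.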
Pass the `generator` to the pipeline.
```py
import torch
import matplotlib.pyplot as plt
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16
).to("cuda")
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
prompts = [
"cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
"pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
]
images = pipeline(
prompt=prompts,
generator=generator
).images
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
axes = axes.flatten()
for i, image in enumerate(images):
    axes[i].imshow(image)
    axes[i].set_title(f"Image {i+1}")
    axes[i].axis('off')
plt.tight_layout()
plt.show()
```
You can use this to iteratively select an image associated with a seed and then improve on it by crafting a more detailed prompt.
+20 -29
@@ -70,41 +70,32 @@ pipeline = StableDiffusionPipeline.from_single_file(
</hfoption>
</hfoptions>
#### LoRA files
#### LoRAs
[LoRA](https://hf.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a lightweight adapter that is fast and easy to train, making them especially popular for generating images in a certain way or style. These adapters are commonly stored in a safetensors file, and are widely popular on model sharing platforms like [civitai](https://civitai.com/).
[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or video in a specific style. If you are using a checkpoint trained with a Diffusers training script, the LoRA configuration is automatically saved as metadata in a safetensors file. When the safetensors file is loaded, the metadata is parsed to correctly configure the LoRA, which avoids missing or incorrect LoRA configurations.
LoRAs are loaded into a base model with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method.
```py
from diffusers import StableDiffusionXLPipeline
import torch
# base model
pipeline = StableDiffusionXLPipeline.from_pretrained(
"Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
# download LoRA weights
!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
# load LoRA weights
pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
image = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
generator=torch.manual_seed(0),
).images[0]
image
```
The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weights.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/blueprint-lora.png"/>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/safetensors_lora.png"/>
</div>
For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`] as long as it is a safetensors file.
```py
import torch
from diffusers import FluxPipeline
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
pipeline.save_lora_weights(
transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
)
```
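The metadata can also be read offline. The sketch below relies only on the published safetensors file layout (an 8-byte little-endian header length followed by a JSON header whose optional `__metadata__` entry holds string key/value pairs). `read_safetensors_metadata` is a hypothetical helper, the demo file is hand-written and contains no tensors, and which keys Diffusers stores there is not assumed.

```py
import json
import os
import struct
import tempfile


def read_safetensors_metadata(path):
    """Return the free-form string metadata stored in a safetensors header."""
    with open(path, "rb") as f:
        # The first 8 bytes hold the JSON header size as a little-endian uint64.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # Tensor entries sit next to the optional "__metadata__" entry.
    return header.get("__metadata__", {})


# Demo on a tiny tensor-free file so the sketch runs without any weights.
demo_header = json.dumps({"__metadata__": {"format": "pt"}}).encode("utf-8")
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.safetensors")
    with open(demo_path, "wb") as f:
        f.write(struct.pack("<Q", len(demo_header)) + demo_header)
    metadata = read_safetensors_metadata(demo_path)

print(metadata)
```

Pointing the helper at a downloaded LoRA file should return an empty dictionary when no metadata was saved.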
### ckpt
> [!WARNING]
@@ -136,53 +136,3 @@ result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="
print("L_inf dist =", abs(result1 - result2).max())
"L_inf dist = tensor(0., device='cuda:0')"
```
## Deterministic batch generation
A practical application of creating reproducible pipelines is *deterministic batch generation*. You generate a batch of images and select one image to improve with a more detailed prompt. The main idea is to pass a list of [Generator's](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed so you can reuse it.
Let's use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint and generate a batch of images.
```py
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import make_image_grid
pipeline = DiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
)
pipeline = pipeline.to("cuda")
```
Define four different `Generator`s and assign each `Generator` a seed (`0` to `3`). Then generate a batch of images and pick one to iterate on.
> [!WARNING]
> Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. If you multiply the `Generator` by the batch size integer, it only creates *one* `Generator` object that is used sequentially for each image in the batch.
>
> ```py
> [torch.Generator().manual_seed(seed)] * 4
> ```
```python
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
prompt = "Labrador in the style of Vermeer"
images = pipeline(prompt, generator=generator, num_images_per_prompt=4).images[0]
make_image_grid(images, rows=2, cols=2)
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds.jpg"/>
</div>
Let's improve the first image (you can choose any image you want) which corresponds to the `Generator` with seed `0`. Add some additional text to your prompt and then make sure you reuse the same `Generator` with seed `0`. All the generated images should resemble the first image.
```python
prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
images = pipeline(prompt, generator=generator).images
make_image_grid(images, rows=2, cols=2)
```
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds_2.jpg"/>
</div>
@@ -242,3 +242,15 @@ unet = UNet2DConditionModel.from_pretrained(
)
unet.save_pretrained("./local-unet", variant="non_ema")
```
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
```py
from diffusers import AutoModel
unet = AutoModel.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
```
You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
+49 -3
@@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N
## Training Kontext
[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too.
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.
Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
**important**
> [!NOTE]
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
> To do this, execute the following steps in a new virtual environment:
> ```
> git clone https://github.com/huggingface/diffusers
> cd diffusers
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
> pip install -e .
> ```
Below is an example training command:
@@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
perform as expected.
Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
* Condition image
* Target image
* Instruction
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
```bash
accelerate launch train_dreambooth_lora_flux_kontext.py \
--pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
--output_dir="kontext-i2i" \
--dataset_name="kontext-community/relighting" \
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--optimizer="adamw" \
--use_8bit_adam \
--cache_latents \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--lr_warmup_steps=200 \
--max_train_steps=1000 \
--rank=16 \
--seed="0"
```
More generally, when performing I2I fine-tuning, we expect you to:
* Have a dataset structured like `kontext-community/relighting`, with a condition image, a target image, and an instruction for each example
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
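Before launching a long run, it can help to confirm that the three column names actually exist in the dataset. The sketch below is a standalone check written for illustration; `check_i2i_columns` is a hypothetical helper, not a function from the training script, and the column list is copied from the launch command above rather than read from the Hub.

```py
def check_i2i_columns(column_names, image_column, cond_image_column, caption_column):
    """Raise ValueError if any column required for I2I training is absent."""
    requested = {
        "--image_column": image_column,
        "--cond_image_column": cond_image_column,
        "--caption_column": caption_column,
    }
    missing = [f"{flag}={name!r}" for flag, name in requested.items() if name not in column_names]
    if missing:
        raise ValueError(
            f"Not found in dataset columns: {', '.join(missing)}. "
            f"Dataset columns are: {', '.join(column_names)}"
        )


# Column names taken from the launch command; the real dataset may expose more.
columns = ["file_name", "output", "instruction"]
check_i2i_columns(columns, "output", "file_name", "instruction")  # passes silently

error_message = None
try:
    check_i2i_columns(columns, "output", "source", "instruction")
except ValueError as err:
    error_message = str(err)
print(error_message)
```

With a real dataset, `column_names` would come from the loaded training split.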
### Misc notes
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
@@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a
Since Flux Kontext fine-tuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗
## Other notes
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
@@ -40,7 +40,7 @@ from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from torchvision import transforms
from torchvision.transforms.functional import crop
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
@@ -62,11 +62,7 @@ from diffusers.training_utils import (
free_memory,
parse_buckets_string,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -186,6 +182,7 @@ def log_validation(
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
pipeline_args_cp = pipeline_args.copy()
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
@@ -193,14 +190,16 @@ def log_validation(
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
prompt = pipeline_args_cp.pop("prompt")
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
**pipeline_args_cp,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
generator=generator,
).images[0]
images.append(image)
@@ -310,6 +309,12 @@ def parse_args(input_args=None):
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
parser.add_argument(
"--cond_image_column",
type=str,
default=None,
help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
)
parser.add_argument(
"--caption_column",
type=str,
@@ -330,7 +335,6 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -351,6 +355,12 @@ def parse_args(input_args=None):
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.",
)
parser.add_argument(
"--num_validation_images",
type=int,
@@ -399,7 +409,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
default="flux-dreambooth-lora",
default="flux-kontext-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -716,6 +726,8 @@ def parse_args(input_args=None):
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
if args.cond_image_column is not None:
raise ValueError("Prior preservation isn't supported with I2I training.")
else:
# logger is not available yet
if args.class_data_dir is not None:
@@ -723,6 +735,14 @@ def parse_args(input_args=None):
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
if args.cond_image_column is not None:
assert args.image_column is not None
assert args.caption_column is not None
assert args.dataset_name is not None
assert not args.train_text_encoder
if args.validation_prompt is not None:
assert args.validation_image is not None and os.path.exists(args.validation_image)
return args
@@ -742,6 +762,7 @@ class DreamBoothDataset(Dataset):
repeats=1,
center_crop=False,
buckets=None,
args=None,
):
self.center_crop = center_crop
@@ -774,6 +795,10 @@ class DreamBoothDataset(Dataset):
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.cond_image_column is not None and args.cond_image_column not in column_names:
raise ValueError(
f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
@@ -783,7 +808,12 @@ class DreamBoothDataset(Dataset):
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))]
cond_images = None
cond_image_column = args.cond_image_column
if cond_image_column is not None:
cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
assert len(instance_images) == len(cond_images)
if args.caption_column is None:
logger.info(
@@ -811,14 +841,23 @@ class DreamBoothDataset(Dataset):
self.custom_instance_prompts = None
self.instance_images = []
for img in instance_images:
self.cond_images = []
for i, img in enumerate(instance_images):
self.instance_images.extend(itertools.repeat(img, repeats))
if args.dataset_name is not None and cond_images is not None:
self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
self.pixel_values = []
for image in self.instance_images:
self.cond_pixel_values = []
for i, image in enumerate(self.instance_images):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
dest_image = None
if self.cond_images:
dest_image = exif_transpose(self.cond_images[i])
if not dest_image.mode == "RGB":
dest_image = dest_image.convert("RGB")
width, height = image.size
@@ -828,25 +867,16 @@ class DreamBoothDataset(Dataset):
self.size = (target_height, target_width)
# based on the bucket assignment, define the transformations
train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
image, dest_image = self.paired_transform(
image,
dest_image=dest_image,
size=self.size,
center_crop=args.center_crop,
random_flip=args.random_flip,
)
image = train_resize(image)
if args.center_crop:
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, self.size)
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
image = train_flip(image)
image = train_transforms(image)
self.pixel_values.append((image, bucket_idx))
if dest_image is not None:
self.cond_pixel_values.append((dest_image, bucket_idx))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
@@ -880,6 +910,9 @@ class DreamBoothDataset(Dataset):
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx
if self.cond_pixel_values:
dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
example["cond_images"] = dest_image
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -902,6 +935,43 @@ class DreamBoothDataset(Dataset):
return example
def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
# 1. Resize (deterministic)
resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
image = resize(image)
if dest_image is not None:
dest_image = resize(dest_image)
# 2. Crop: either center or SAME random crop
if center_crop:
crop = transforms.CenterCrop(size)
image = crop(image)
if dest_image is not None:
dest_image = crop(dest_image)
else:
# get_params returns (i, j, h, w)
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
image = TF.crop(image, i, j, h, w)
if dest_image is not None:
dest_image = TF.crop(dest_image, i, j, h, w)
# 3. Random horizontal flip with the SAME coin flip
if random_flip:
do_flip = random.random() < 0.5
if do_flip:
image = TF.hflip(image)
if dest_image is not None:
dest_image = TF.hflip(dest_image)
# 4. ToTensor + Normalize (deterministic)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
image = normalize(to_tensor(image))
if dest_image is not None:
dest_image = normalize(to_tensor(dest_image))
return image, dest_image
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
@@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values, "prompts": prompts}
if any("cond_images" in example for example in examples):
cond_pixel_values = [example["cond_images"] for example in examples]
cond_pixel_values = torch.stack(cond_pixel_values)
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
batch.update({"cond_pixel_values": cond_pixel_values})
return batch
@@ -1318,6 +1393,7 @@ def main(args):
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
"proj_mlp",
]
# now we will add new LoRA weights the transformer layers
@@ -1534,7 +1610,10 @@ def main(args):
buckets=buckets,
repeats=args.repeats,
center_crop=args.center_crop,
args=args,
)
if args.cond_image_column is not None:
logger.info("I2I fine-tuning enabled.")
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
@@ -1574,6 +1653,7 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
text_encoder_one.cpu(), text_encoder_two.cpu()
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
@@ -1605,19 +1685,41 @@ def main(args):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
cached_text_embeddings = []
for batch in tqdm(train_dataloader, desc="Embedding prompts"):
batch_prompts = batch["prompts"]
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
batch_prompts, text_encoders, tokenizers
)
cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
if args.validation_prompt is None:
text_encoder_one.cpu(), text_encoder_two.cpu()
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
has_image_input = args.cond_image_column is not None
if args.cache_latents:
latents_cache = []
cond_latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if has_image_input:
batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
if args.validation_prompt is None:
vae.cpu()
del vae
free_memory()
@@ -1678,7 +1780,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_name = "dreambooth-flux-dev-lora"
tracker_name = "dreambooth-flux-kontext-lora"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
@@ -1742,6 +1844,7 @@ def main(args):
sigma = sigma.unsqueeze(-1)
return sigma
has_guidance = unwrap_model(transformer).config.guidance_embeds
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
@@ -1759,9 +1862,7 @@ def main(args):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
)
prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
tokens_two = tokenize_prompt(
@@ -1794,16 +1895,29 @@ def main(args):
if args.cache_latents:
if args.vae_encode_mode == "sample":
model_input = latents_cache[step].sample()
if has_image_input:
cond_model_input = cond_latents_cache[step].sample()
else:
model_input = latents_cache[step].mode()
if has_image_input:
cond_model_input = cond_latents_cache[step].mode()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
if has_image_input:
cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
if args.vae_encode_mode == "sample":
model_input = vae.encode(pixel_values).latent_dist.sample()
if has_image_input:
cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
else:
model_input = vae.encode(pixel_values).latent_dist.mode()
if has_image_input:
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
if has_image_input:
cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor
cond_model_input = cond_model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
@@ -1814,6 +1928,17 @@ def main(args):
accelerator.device,
weight_dtype,
)
if has_image_input:
cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(
cond_model_input.shape[0],
cond_model_input.shape[2] // 2,
cond_model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
cond_latents_ids[..., 0] = 1
latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
@@ -1834,7 +1959,6 @@ def main(args):
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
packed_noisy_model_input = FluxKontextPipeline._pack_latents(
noisy_model_input,
batch_size=model_input.shape[0],
@@ -1842,13 +1966,22 @@ def main(args):
height=model_input.shape[2],
width=model_input.shape[3],
)
orig_inp_shape = packed_noisy_model_input.shape
if has_image_input:
packed_cond_input = FluxKontextPipeline._pack_latents(
cond_model_input,
batch_size=cond_model_input.shape[0],
num_channels_latents=cond_model_input.shape[1],
height=cond_model_input.shape[2],
width=cond_model_input.shape[3],
)
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
# handle guidance
if unwrap_model(transformer).config.guidance_embeds:
# Kontext always has guidance
guidance = None
if has_guidance:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
guidance = None
# Predict the noise residual
model_pred = transformer(
@@ -1862,6 +1995,8 @@ def main(args):
img_ids=latent_image_ids,
return_dict=False,
)[0]
if has_image_input:
model_pred = model_pred[:, : orig_inp_shape[1]]
model_pred = FluxKontextPipeline._unpack_latents(
model_pred,
height=model_input.shape[2] * vae_scale_factor,
@@ -1970,6 +2105,8 @@ def main(args):
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
if has_image_input and args.validation_image:
pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,
@@ -2030,6 +2167,8 @@ def main(args):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
if has_image_input and args.validation_image:
pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,
@@ -837,11 +837,6 @@ def main(args):
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
if args.train_norm_layers:
for name, param in flux_transformer.named_parameters():
if any(k in name for k in NORM_LAYER_PREFIXES):
param.requires_grad = True
if args.lora_layers is not None:
if args.lora_layers != "all-linear":
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
)
flux_transformer.add_adapter(transformer_lora_config)
if args.train_norm_layers:
for name, param in flux_transformer.named_parameters():
if any(k in name for k in NORM_LAYER_PREFIXES):
param.requires_grad = True
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
-1
@@ -95,7 +95,6 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
# "extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
+6
@@ -133,9 +133,11 @@ else:
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_pyramid_attention_broadcast",
]
)
@@ -381,6 +383,7 @@ else:
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxKontextInpaintPipeline",
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
@@ -750,9 +753,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
from .models import (
@@ -975,6 +980,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxKontextInpaintPipeline,
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
+15
@@ -1,8 +1,23 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
+30
@@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
+264
@@ -0,0 +1,264 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
@dataclass
class AttentionProcessorMetadata:
skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class TransformerBlockMetadata:
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
return kwargs[identifier]
if self._cached_parameter_indices is not None:
return args[self._cached_parameter_indices[identifier]]
if self._cls is None:
raise ValueError("Model class is not set for metadata.")
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
if identifier not in self._cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
index = self._cached_parameter_indices[identifier]
if index >= len(args):
raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")
return args[index]
class AttentionProcessorRegistry:
_registry = {}
# TODO(aryan): this is only required for the time being because we need to do the registrations
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
# import errors because of the models imported in this file.
_is_registered = False
@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._register()
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
cls._register()
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
@classmethod
def _register(cls):
if cls._is_registered:
return
cls._is_registered = True
_register_attention_processors_metadata()
class TransformerBlockRegistry:
_registry = {}
# TODO(aryan): this is only required for the time being because we need to do the registrations
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
# import errors because of the models imported in this file.
_is_registered = False
@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._register()
metadata._cls = model_class
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
cls._register()
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
@classmethod
def _register(cls):
if cls._is_registered:
return
cls._is_registered = True
_register_transformer_blocks_metadata()
def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
# AttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=AttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
),
)
# CogView4AttnProcessor
AttentionProcessorRegistry.register(
model_class=CogView4AttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
# BasicTransformerBlock
TransformerBlockRegistry.register(
model_class=BasicTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# CogVideoX
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# CogView4
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Flux
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# HunyuanVideo
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# LTXVideo
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# Mochi
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
# fmt: on
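The parameter lookup in `TransformerBlockMetadata._get_parameter_from_args_kwargs` can be illustrated in isolation. The following is a minimal sketch, not the library's implementation: `ToyBlock` and `get_parameter` are hypothetical names introduced here, and the sketch omits the index caching that the real method performs.

```python
import inspect
from typing import Any, Dict, Optional, Tuple


class ToyBlock:
    """Hypothetical stand-in for a transformer block; only its forward signature matters."""

    def forward(self, hidden_states, encoder_hidden_states=None, temb=None):
        return hidden_states


def get_parameter(
    block_cls: type,
    identifier: str,
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
) -> Any:
    """Return the value passed for `identifier`, whether it arrived by keyword or by position."""
    kwargs = kwargs or {}
    # Keyword arguments take priority, as in the original method.
    if identifier in kwargs:
        return kwargs[identifier]
    # Map each parameter name to its positional index, skipping `self`.
    names = list(inspect.signature(block_cls.forward).parameters)[1:]
    indices = {name: i for i, name in enumerate(names)}
    if identifier not in indices:
        raise ValueError(f"Parameter '{identifier}' not found in the forward signature.")
    index = indices[identifier]
    if index >= len(args):
        raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")
    return args[index]


# Positional lookup, then keyword lookup.
print(get_parameter(ToyBlock, "encoder_hidden_states", ("h", "e")))
print(get_parameter(ToyBlock, "temb", ("h",), {"temb": "t"}))
```

Resolving by name rather than by a fixed position lets a hook read `hidden_states` from blocks whose `forward` signatures order their arguments differently.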
+227
@@ -0,0 +1,227 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
but might lead to poorer generation quality. A lower threshold may not result in significant generation
speedup. The threshold is compared against the mean absolute difference between the current and cached
residuals of the first transformer block, relative to the mean absolute value of the cached residual. If
this relative difference does not exceed the threshold, the remaining blocks are skipped and their cached
residuals are reused.
"""
threshold: float = 0.05
class FBCSharedBlockState(BaseState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
self.tail_block_residuals = None
self.should_compute = True
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, state_manager: StateManager, threshold: float):
self.state_manager = state_manager
self.threshold = threshold
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)
if is_output_tuple:
hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
else:
hidden_states_residual = output - original_hidden_states
shared_state: FBCSharedBlockState = self.state_manager.get_state()
hidden_states = encoder_hidden_states = None
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
shared_state.should_compute = should_compute
if not should_compute:
# Apply caching
if is_output_tuple:
hidden_states = (
shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
)
else:
hidden_states = shared_state.tail_block_residuals[0] + output
if self._metadata.return_encoder_hidden_states_index is not None:
assert is_output_tuple
encoder_hidden_states = (
shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
)
if is_output_tuple:
return_output = [None] * len(output)
return_output[self._metadata.return_hidden_states_index] = hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
return_output = tuple(return_output)
else:
return_output = hidden_states
output = return_output
else:
if is_output_tuple:
head_block_output = [None] * len(output)
head_block_output[0] = output[self._metadata.return_hidden_states_index]
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
else:
head_block_output = output
shared_state.head_block_output = head_block_output
shared_state.head_block_residual = hidden_states_residual
return output
def reset_state(self, module):
self.state_manager.reset()
return module
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
shared_state = self.state_manager.get_state()
if shared_state.head_block_residual is None:
return True
prev_hidden_states_residual = shared_state.head_block_residual
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
diff = (absmean / prev_hidden_states_absmean).item()
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, state_manager: StateManager, is_tail: bool = False):
super().__init__()
self.state_manager = state_manager
self.is_tail = is_tail
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
original_encoder_hidden_states = None
if self._metadata.return_encoder_hidden_states_index is not None:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
shared_state = self.state_manager.get_state()
if shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
hidden_states_residual = encoder_hidden_states_residual = None
if isinstance(output, tuple):
hidden_states_residual = (
output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
)
encoder_hidden_states_residual = (
output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
)
else:
hidden_states_residual = output - shared_state.head_block_output
shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
return output
if original_encoder_hidden_states is None:
return_output = original_hidden_states
else:
return_output = [None, None]
return_output[self._metadata.return_hidden_states_index] = original_hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
return_output = tuple(return_output)
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
state_manager = StateManager(FBCSharedBlockState, (), {})
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Applying FBCBlockHook to '{name}'")
_apply_fbc_block_hook(block, state_manager)
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state_manager, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state_manager, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)
+139 -152
@@ -14,6 +14,8 @@
import os
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
@@ -46,6 +48,24 @@ _SUPPORTED_PYTORCH_LAYERS = (
# fmt: on
class GroupOffloadingType(str, Enum):
BLOCK_LEVEL = "block_level"
LEAF_LEVEL = "leaf_level"
@dataclass
class GroupOffloadingConfig:
onload_device: torch.device
offload_device: torch.device
offload_type: GroupOffloadingType
non_blocking: bool
record_stream: bool
low_cpu_mem_usage: bool
num_blocks_per_group: Optional[int] = None
offload_to_disk_path: Optional[str] = None
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
class ModuleGroup:
def __init__(
self,
@@ -288,9 +308,12 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
def __init__(
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
) -> None:
self.group = group
self.next_group = next_group
self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
@@ -436,7 +459,7 @@ def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
offload_type: str = "block_level",
offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -478,7 +501,7 @@ def apply_group_offloading(
The device to which the group of modules are onloaded.
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
offload_type (`str`, defaults to "block_level"):
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -521,6 +544,8 @@ def apply_group_offloading(
```
"""
offload_type = GroupOffloadingType(offload_type)
stream = None
if use_stream:
if torch.cuda.is_available():
@@ -532,84 +557,45 @@ def apply_group_offloading(
if not use_stream and record_stream:
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'`.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
config = GroupOffloadingConfig(
onload_device=onload_device,
offload_device=offload_device,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
)
_apply_group_offloading(module, config)
_apply_group_offloading_block_level(
module=module,
num_blocks_per_group=num_blocks_per_group,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
module=module,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
)
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
_apply_group_offloading_block_level(module, config)
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
_apply_group_offloading_leaf_level(module, config)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
assert False
def _apply_group_offloading_block_level(
module: torch.nn.Module,
num_blocks_per_group: int,
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
RAM environment settings where a reasonable speed-memory trade-off is desired.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
if stream is not None and num_blocks_per_group != 1:
if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
num_blocks_per_group = 1
config.num_blocks_per_group = 1
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -621,19 +607,19 @@ def _apply_group_offloading_block_level(
modules_with_group_offloading.add(name)
continue
for i in range(0, len(submodule), num_blocks_per_group):
current_modules = submodule[i : i + num_blocks_per_group]
for i in range(0, len(submodule), config.num_blocks_per_group):
current_modules = submodule[i : i + config.num_blocks_per_group]
group = ModuleGroup(
modules=current_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
)
matched_module_groups.append(group)
@@ -643,7 +629,7 @@ def _apply_group_offloading_block_level(
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
_apply_group_offloading_hook(group_module, group, None)
_apply_group_offloading_hook(group_module, group, None, config=config)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -658,9 +644,9 @@ def _apply_group_offloading_block_level(
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
@@ -670,54 +656,19 @@ def _apply_group_offloading_block_level(
record_stream=False,
onload_self=True,
)
if stream is None:
_apply_group_offloading_hook(module, unmatched_group, None)
if config.stream is None:
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
else:
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
def _apply_group_offloading_leaf_level(
module: torch.nn.Module,
offload_device: torch.device,
onload_device: torch.device,
non_blocking: bool,
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
) -> None:
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
reduce memory usage without any performance degradation.
Args:
module (`torch.nn.Module`):
The module to which group offloading is applied.
offload_device (`torch.device`):
The device to which the group of modules are offloaded. This should typically be the CPU.
onload_device (`torch.device`):
The device to which the group of modules are onloaded.
offload_to_disk_path (`str`, *optional*, defaults to `None`):
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
RAM environment settings where a reasonable speed-memory trade-off is desired.
non_blocking (`bool`):
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
and data transfer.
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
for overlapping computation and data transfer.
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) for more
details.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
@@ -725,18 +676,18 @@ def _apply_group_offloading_leaf_level(
continue
group = ModuleGroup(
modules=[submodule],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
_apply_group_offloading_hook(submodule, group, None, config=config)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -767,33 +718,32 @@ def _apply_group_offloading_leaf_level(
parameters = parent_to_parameters.get(name, [])
buffers = parent_to_buffers.get(name, [])
parent_module = module_dict[name]
assert getattr(parent_module, "_diffusers_hook", None) is None
group = ModuleGroup(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
offload_to_disk_path=offload_to_disk_path,
offload_to_disk_path=config.offload_to_disk_path,
parameters=parameters,
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
non_blocking=config.non_blocking,
stream=config.stream,
record_stream=config.record_stream,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
_apply_group_offloading_hook(parent_module, group, None, config=config)
if stream is not None:
if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
# execution order and apply prefetching in the correct order.
unmatched_group = ModuleGroup(
modules=[],
offload_device=offload_device,
onload_device=onload_device,
offload_to_disk_path=offload_to_disk_path,
offload_device=config.offload_device,
onload_device=config.onload_device,
offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
@@ -801,23 +751,25 @@ def _apply_group_offloading_leaf_level(
non_blocking=False,
stream=None,
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group)
hook = GroupOffloadingHook(group, next_group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
@@ -825,13 +777,15 @@ def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
*,
config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
hook = GroupOffloadingHook(group, next_group)
hook = GroupOffloadingHook(group, next_group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -898,15 +852,48 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
)
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return True
return False
if hasattr(submodule, "_diffusers_hook"):
group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
if group_offloading_hook is not None:
return group_offloading_hook
return None
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
return top_level_group_offload_hook is not None
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
for submodule in module.modules():
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
if top_level_group_offload_hook is not None:
return top_level_group_offload_hook.config.onload_device
raise ValueError("Group offloading is not enabled for the provided module.")
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
r"""
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
modified in-place and the tensor references held by the group offloading hook need to be updated. The in-place
modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
In this implementation, we assume that group offloading has only been applied at the top-level module,
and therefore all submodules have the same onload and offload devices. If this assumption does not hold, say
because the user has applied group offloading at multiple levels, this function will not work as expected.
There is some performance penalty associated with doing this when non-default streams are used, because we need to
retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
"""
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
if top_level_group_offload_hook is None:
return
registry = HookRegistry.check_if_exists_or_initialize(module)
registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
_apply_group_offloading(module, top_level_group_offload_hook.config)
+56 -1
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseState:
def reset(self, *args, **kwargs) -> None:
raise NotImplementedError(
"BaseState::reset is not implemented. Please implement this method in the derived class."
)
class StateManager:
def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
self._state_cls = state_cls
self._init_args = init_args if init_args is not None else ()
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
self._state_cache = {}
self._current_context = None
def get_state(self):
if self._current_context is None:
raise ValueError("No context is set. Please set a context before retrieving the state.")
if self._current_context not in self._state_cache.keys():
self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
return self._state_cache[self._current_context]
def set_context(self, name: str) -> None:
self._current_context = name
def reset(self, *args, **kwargs) -> None:
for name, state in list(self._state_cache.items()):
state.reset(*args, **kwargs)
self._state_cache.pop(name)
self._current_context = None
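The per-context behaviour of `StateManager` can be sketched in isolation. This is a minimal, dependency-free illustration that copies the class as it appears in this diff; `CounterState` and the context names `"cond"` and `"uncond"` are hypothetical and exist only for this example.

```python
class BaseState:
    def reset(self, *args, **kwargs) -> None:
        raise NotImplementedError("Derived classes must implement reset().")


class CounterState(BaseState):
    """Hypothetical state used only for this illustration."""

    def __init__(self, start: int = 0) -> None:
        self.start = start
        self.count = start

    def reset(self, *args, **kwargs) -> None:
        self.count = self.start


class StateManager:
    def __init__(self, state_cls, init_args=None, init_kwargs=None):
        self._state_cls = state_cls
        self._init_args = init_args if init_args is not None else ()
        self._init_kwargs = init_kwargs if init_kwargs is not None else {}
        self._state_cache = {}
        self._current_context = None

    def get_state(self):
        if self._current_context is None:
            raise ValueError("No context is set.")
        # A state object is created lazily, once per context name.
        if self._current_context not in self._state_cache:
            self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
        return self._state_cache[self._current_context]

    def set_context(self, name: str) -> None:
        self._current_context = name

    def reset(self, *args, **kwargs) -> None:
        # Reset every cached state, drop it, and clear the active context.
        for name, state in list(self._state_cache.items()):
            state.reset(*args, **kwargs)
            self._state_cache.pop(name)
        self._current_context = None


manager = StateManager(CounterState, init_kwargs={"start": 10})

manager.set_context("cond")
manager.get_state().count += 1

manager.set_context("uncond")
manager.get_state().count += 5

# Switching back returns the state that was created for that context.
manager.set_context("cond")
cond_count = manager.get_state().count
manager.set_context("uncond")
uncond_count = manager.get_state().count

# After reset, no context is active, so get_state() raises ValueError.
manager.reset()
```

Each context name gets its own lazily created state, so separate batches (for example conditional and unconditional passes) do not overwrite each other's cached values.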
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
def _set_context(self, module: torch.nn.Module, name: str) -> None:
# Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, StateManager):
attr.set_context(name)
return module
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in self._module_ref.named_modules():
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +265,19 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
def _set_context(self, name: Optional[str] = None) -> None:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook._set_context(self._module_ref, name)
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._set_context(name)
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
+55 -17
@@ -25,6 +25,7 @@ import torch.nn as nn
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
from ..models.modeling_utils import ModelMixin, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
adapter_name = get_adapter_name(text_encoder)
# <Unsafe code
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
_pipeline
)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
A tuple of booleans `(is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)` indicating which offloading methods are enabled.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
if not isinstance(component, nn.Module):
continue
is_group_offload = is_group_offload or _is_group_offload_enabled(component)
if not hasattr(component, "_hf_hook"):
continue
is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
is_sequential_cpu_offload = is_sequential_cpu_offload or (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
if is_sequential_cpu_offload or is_model_cpu_offload:
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
if is_sequential_cpu_offload or is_model_cpu_offload:
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
for _, component in _pipeline.components.items():
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
continue
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
class LoraBaseMixin:
@@ -921,6 +934,27 @@ class LoraBaseMixin:
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
GPU before using those LoRA adapters for inference.
```python
>>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
>>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
>>> pipe.set_adapters("adapter-1")
>>> image_1 = pipe(**kwargs)
>>> # switch to adapter-2, offload adapter-1
>>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
>>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
>>> pipe.set_adapters("adapter-2")
>>> image_2 = pipe(**kwargs)
>>> # switch back to adapter-1, offload adapter-2
>>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
>>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
>>> pipe.set_adapters("adapter-1")
>>> ...
```
Args:
adapter_names (`List[str]`):
List of adapters to move to `device`.
@@ -936,6 +970,10 @@ class LoraBaseMixin:
for module in model.modules():
if isinstance(module, BaseTunerLayer):
for adapter_name in adapter_names:
if adapter_name not in module.lora_A:
# it is sufficient to check lora_A
continue
module.lora_A[adapter_name].to(device)
module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
+237 -17
@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
return converted_state_dict
def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
converted_state_dict = {}
original_state_dict_keys = list(original_state_dict.keys())
num_layers = 19
num_single_layers = 38
inner_dim = 3072
mlp_ratio = 4.0
# double transformer blocks
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
original_block_prefix = "base_model.model."
for lora_key in ["lora_A", "lora_B"]:
# norms
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
)
converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
)
# Q, K, V
if lora_key == "lora_A":
sample_lora_weight = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
context_lora_weight = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
[context_lora_weight]
)
else:
sample_q, sample_k, sample_v = torch.chunk(
original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
context_q, context_k, context_v = torch.chunk(
original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
if f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
if f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
3,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
# ff img_mlp
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
)
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
)
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
)
# output projections.
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
)
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
)
if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
)
# single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
for lora_key in ["lora_A", "lora_B"]:
# norm.linear <- single_blocks.0.modulation.lin
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
)
if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
)
# Q, K, V, mlp
mlp_hidden_dim = int(inner_dim * mlp_ratio)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
if lora_key == "lora_A":
lora_weight = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
lora_bias = original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias")
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
else:
q, k, v, mlp = torch.split(
original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
split_size,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
q_bias, k_bias, v_bias, mlp_bias = torch.split(
original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
split_size,
dim=0,
)
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
# output projections.
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
)
if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
)
for lora_key in ["lora_A", "lora_B"]:
converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
)
if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
)
if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
for key in list(converted_state_dict.keys()):
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
return converted_state_dict
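The reason a fused QKV LoRA can be split this way (share `lora_A`, chunk `lora_B` along dim 0) can be checked with a small sketch. It is dependency-free: nested lists stand in for tensors, the row slicing stands in for `torch.chunk(..., 3, dim=0)`, and the toy sizes and the `matmul` helper are assumptions made only for this illustration.

```python
def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


# Toy sizes (hypothetical): the real conversion uses inner_dim = 3072.
rank, in_features, inner_dim = 2, 3, 2

# lora_A has shape (rank, in_features) and is shared by q, k and v.
lora_A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]

# The fused lora_B has shape (3 * inner_dim, rank): q, k, v rows stacked.
lora_B = [[float(r * rank + c) for c in range(rank)] for r in range(3 * inner_dim)]

# Row-wise split into three equal chunks, mirroring a chunk along dim 0.
b_q, b_k, b_v = (lora_B[j * inner_dim:(j + 1) * inner_dim] for j in range(3))

# Weight update of the fused layer versus the stacked per-projection updates.
fused_delta = matmul(lora_B, lora_A)
split_delta = matmul(b_q, lora_A) + matmul(b_k, lora_A) + matmul(b_v, lora_A)

assert fused_delta == split_delta
```

Because each output row of `lora_B @ lora_A` depends on a single row of `lora_B`, copying `lora_A` to all three projections and giving each its own slice of `lora_B` reproduces the fused update exactly.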
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
@@ -1603,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
has_time_projection_weight = any(
k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
)
diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
if diff_keys:
for diff_k in diff_keys:
param = original_state_dict[diff_k]
# The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are up to 1e-3,
# and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
# to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
# is okay to ignore because they do not affect the model output in a significant manner.
threshold = 1.6e-2
absdiff = param.abs().max() - param.abs().min()
all_zero = torch.all(param == 0).item()
all_absdiff_lower_than_threshold = absdiff < threshold
if all_zero or all_absdiff_lower_than_threshold:
logger.debug(
f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
)
original_state_dict.pop(diff_k)
for key in list(original_state_dict.keys()):
if key.endswith((".diff", ".diff_b")) and "norm" in key:
# NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
# in the future if needed and if they are not zeroed.
original_state_dict.pop(key)
logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
if "time_projection" in key and not has_time_projection_weight:
# AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
# our lora config adds the time proj lora layers, but we don't have the weights for them.
# CausVid lora has the weight keys and the bias keys.
original_state_dict.pop(key)
# For the `diff_b` keys, we treat them as lora_bias.
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
+12
@@ -41,6 +41,7 @@ from .lora_base import ( # noqa
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
@@ -2062,6 +2063,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
return_metadata=return_lora_metadata,
)
is_fal_kontext = any("base_model" in k for k in state_dict)
if is_fal_kontext:
state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
return cls._prepare_outputs(
state_dict,
metadata=metadata,
alphas=None,
return_alphas=return_alphas,
return_metadata=return_lora_metadata,
)
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
+22 -4
@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
@@ -243,20 +244,29 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
# create LoraConfig
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
# create LoraConfig
lora_config = _create_lora_config(
state_dict,
network_alphas,
metadata,
rank,
model_state_dict=self.state_dict(),
adapter_name=adapter_name,
)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error.
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline
)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +357,10 @@ class PeftAdapterMixin:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
if prefix is not None and not state_dict:
@@ -686,6 +700,10 @@ class PeftAdapterMixin:
recurse_remove_peft_layers(self)
if hasattr(self, "peft_config"):
del self.peft_config
if hasattr(self, "_hf_peft_config_loaded"):
self._hf_peft_config_loaded = None
_maybe_remove_and_reapply_group_offloading(self)
def disable_lora(self):
"""
@@ -31,6 +31,7 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
convert_cosmos_transformer_checkpoint_to_diffusers,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
@@ -135,6 +136,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"WanVACETransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
@@ -143,6 +148,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"CosmosTransformer3DModel": {
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
}
+165 -1
@@ -126,7 +126,18 @@ CHECKPOINT_KEY_NAMES = {
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"wan_vace": "vace_blocks.0.after_proj.bias",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
"cosmos-1.0": [
"net.x_embedder.proj.1.weight",
"net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
"net.extra_pos_embedder.pos_emb_h",
],
"cosmos-2.0": [
"net.x_embedder.proj.1.weight",
"net.blocks.0.self_attn.q_proj.weight",
"net.pos_embedder.dim_spatial_range",
],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -192,7 +203,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
"cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
"cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
"cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
"cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
"cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
"cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
}
# Use to configure model sample size when original config is provided
@@ -698,17 +719,44 @@ def infer_diffusers_model_type(checkpoint):
else:
target_key = "patch_embedding.weight"
if checkpoint[target_key].shape[0] == 1536:
if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
if checkpoint[target_key].shape[0] == 1536:
model_type = "wan-vace-1.3B"
elif checkpoint[target_key].shape[0] == 5120:
model_type = "wan-vace-14B"
elif checkpoint[target_key].shape[0] == 1536:
model_type = "wan-t2v-1.3B"
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
model_type = "wan-t2v-14B"
else:
model_type = "wan-i2v-14B"
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
model_type = "hidream"
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
if x_embedder_shape[1] == 68:
model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
elif x_embedder_shape[1] == 72:
model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
else:
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
if x_embedder_shape[1] == 68:
model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
elif x_embedder_shape[1] == 72:
model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
else:
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
else:
model_type = "v1"
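Reviewer sketch (not part of the commit): the Cosmos branches above pick a model type from the shape of the `x_embedder` projection weight. The helper below restates that branching in isolation so it can be checked without a checkpoint. `infer_cosmos_variant` and its `family` argument are hypothetical names; the widths and channel counts are copied from the hunk, and any width other than the "small" one falls through to the larger variant, as in the diff.

```python
def infer_cosmos_variant(x_embedder_shape, family):
    """Map an x_embedder weight shape to a model-type string (illustrative only)."""
    out_features, in_features = x_embedder_shape[0], x_embedder_shape[1]

    # Per-family constants, taken from the branches in the diff.
    if family == "cosmos-1.0":
        small_width, small, large, text_task = 4096, "7B", "14B", "t2w"
    elif family == "cosmos-2.0":
        small_width, small, large, text_task = 2048, "2B", "14B", "t2i"
    else:
        raise ValueError(f"Unknown family: {family}")

    # The second dimension distinguishes the text-conditioned task from video2world.
    if in_features == 68:
        task = text_task
    elif in_features == 72:
        task = "v2w"
    else:
        raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape}")

    size = small if out_features == small_width else large
    return f"{family}-{task}-{size}"
```

The returned strings line up with the keys added to `DIFFUSERS_DEFAULT_PIPELINE_PATHS` in the same file.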
@@ -3093,6 +3141,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# For the VACE model
"before_proj": "proj_in",
"after_proj": "proj_out",
}
for key in list(checkpoint.keys()):
@@ -3479,3 +3530,116 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
return converted_state_dict
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
def remove_keys_(key: str, state_dict):
state_dict.pop(key)
def rename_transformer_blocks_(key: str, state_dict):
block_index = int(key.split(".")[1].removeprefix("block"))
new_key = key
old_prefix = f"blocks.block{block_index}"
new_prefix = f"transformer_blocks.{block_index}"
new_key = new_prefix + new_key.removeprefix(old_prefix)
state_dict[new_key] = state_dict.pop(key)
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
"t_embedder.1": "time_embed.t_embedder",
"affline_norm": "time_embed.norm",
".blocks.0.block.attn": ".attn1",
".blocks.1.block.attn": ".attn2",
".blocks.2.block": ".ff",
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
"to_q.0": "to_q",
"to_q.1": "norm_q",
"to_k.0": "to_k",
"to_k.1": "norm_k",
"to_v.0": "to_v",
"layer1": "net.0.proj",
"layer2": "net.2",
"proj.1": "proj",
"x_embedder": "patch_embed",
"extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
"blocks.block": rename_transformer_blocks_,
"logvar.0.freqs": remove_keys_,
"logvar.0.phases": remove_keys_,
"logvar.1.weight": remove_keys_,
"pos_embedder.seq": remove_keys_,
}
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"t_embedder.1": "time_embed.t_embedder",
"t_embedding_norm": "time_embed.norm",
"blocks": "transformer_blocks",
"adaln_modulation_self_attn.1": "norm1.linear_1",
"adaln_modulation_self_attn.2": "norm1.linear_2",
"adaln_modulation_cross_attn.1": "norm2.linear_1",
"adaln_modulation_cross_attn.2": "norm2.linear_2",
"adaln_modulation_mlp.1": "norm3.linear_1",
"adaln_modulation_mlp.2": "norm3.linear_2",
"self_attn": "attn1",
"cross_attn": "attn2",
"q_proj": "to_q",
"k_proj": "to_k",
"v_proj": "to_v",
"output_proj": "to_out.0",
"q_norm": "norm_q",
"k_norm": "norm_k",
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
}
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
"accum_video_sample_counter": remove_keys_,
"accum_image_sample_counter": remove_keys_,
"accum_iteration": remove_keys_,
"accum_train_in_hours": remove_keys_,
"pos_embedder.seq": remove_keys_,
"pos_embedder.dim_spatial_range": remove_keys_,
"pos_embedder.dim_temporal_range": remove_keys_,
"_extra_state": remove_keys_,
}
PREFIX_KEY = "net."
if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
else:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
state_dict_keys = list(converted_state_dict.keys())
for key in state_dict_keys:
new_key = key[:]
if new_key.startswith(PREFIX_KEY):
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)
state_dict_keys = list(converted_state_dict.keys())
for key in state_dict_keys:
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
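Reviewer sketch (not part of the commit): the converter above strips the `net.` prefix, applies ordered substring renames, then runs special handlers that drop bookkeeping keys. The toy below reproduces that flow on a small subset of the Cosmos 2.0 rename table. `RENAMES`, `DROP_SUBSTRINGS`, and `convert_keys` are hypothetical names, the values are placeholder integers rather than tensors, and the key list is a made-up fragment, not a real checkpoint.

```python
# Subset of the Cosmos 2.0 rename table from the diff, in the same relative order.
RENAMES = {
    "blocks": "transformer_blocks",
    "adaln_modulation_self_attn.1": "norm1.linear_1",
    "self_attn": "attn1",
    "q_proj": "to_q",
    "output_proj": "to_out.0",
}

# Keys containing these substrings are dropped, mirroring the remove_keys_ handler.
DROP_SUBSTRINGS = ("pos_embedder.seq", "pos_embedder.dim_spatial_range")


def convert_keys(state_dict, prefix="net."):
    """Return a new dict with renamed keys; dropped keys are omitted."""
    converted = {}
    for key, value in state_dict.items():
        new_key = key.removeprefix(prefix)  # Python 3.9+
        # Renames are plain substring replacements applied in dict order,
        # so more specific patterns must come before the ones they contain.
        for old, new in RENAMES.items():
            new_key = new_key.replace(old, new)
        if any(marker in new_key for marker in DROP_SUBSTRINGS):
            continue
        converted[new_key] = value
    return converted
```

Because the renames are substring-based, ordering matters: `adaln_modulation_self_attn.1` has to be handled before the shorter `self_attn` rule, which is how the table in the diff is laid out.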
+15 -4
@@ -22,6 +22,7 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
@@ -203,6 +204,7 @@ class UNet2DConditionLoadersMixin:
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
if is_lora:
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +213,7 @@ class UNet2DConditionLoadersMixin:
if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
state_dict=state_dict,
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
@@ -230,7 +232,9 @@ class UNet2DConditionLoadersMixin:
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
if is_custom_diffusion and _pipeline is not None:
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline=_pipeline
)
# only custom diffusion needs to set attn processors
self.set_attn_processor(attn_processors)
@@ -241,6 +245,10 @@ class UNet2DConditionLoadersMixin:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
elif is_group_offload:
for component in _pipeline.components.values():
if isinstance(component, torch.nn.Module):
_maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
@@ -307,6 +315,7 @@ class UNet2DConditionLoadersMixin:
is_model_cpu_offload = False
is_sequential_cpu_offload = False
is_group_offload = False
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
if len(state_dict_to_be_used) > 0:
@@ -356,7 +365,9 @@ class UNet2DConditionLoadersMixin:
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
_pipeline
)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +400,7 @@ class UNet2DConditionLoadersMixin:
if warn_msg:
logger.warning(warn_msg)
return is_model_cpu_offload, is_sequential_cpu_offload
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
+5 -2
@@ -14,6 +14,8 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Union
from torch import nn
from ..utils import logging
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
weight_for_adapter,
blocks_with_transformer,
transformer_per_block,
unet.state_dict(),
model=unet,
default_scale=default_scale,
)
for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
state_dict: None,
model: nn.Module,
default_scale: float = 1.0,
):
"""
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
del scales[updown]
state_dict = model.state_dict()
for layer in scales.keys():
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
+31 -9
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from ..utils.logging import get_logger
@@ -25,6 +27,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -62,8 +65,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
@@ -72,31 +77,36 @@ class CacheMixin:
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, FasterCacheConfig):
if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
elif isinstance(config, FirstBlockCacheConfig):
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, FasterCacheConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry = HookRegistry.check_if_exists_or_initialize(self)
if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, FirstBlockCacheConfig):
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +116,15 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
@contextmanager
def cache_context(self, name: str):
r"""Context manager that provides additional methods for cache management."""
from ..hooks import HookRegistry
registry = HookRegistry.check_if_exists_or_initialize(self)
registry._set_context(name)
yield
registry._set_context(None)
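Reviewer sketch (not part of the commit): `cache_context` names the context that is active while the model runs, so that separate forward passes (for example conditional and unconditional) can keep separate cache state. The toy below illustrates that idea only. `ToyRegistry` and `ToyModel` are hypothetical stand-ins, and keying state by context name is an assumption about how the real hooks use the name. The `try`/`finally` is also an addition of this sketch; the method in the diff resets the context after a plain `yield`.

```python
from contextlib import contextmanager


class ToyRegistry:
    """Hypothetical stand-in for a hook registry that tracks a named context."""

    def __init__(self):
        self.current = None
        self.states = {}

    def set_context(self, name):
        self.current = name

    def state(self):
        # One independent state dict per context name.
        return self.states.setdefault(self.current, {})


class ToyModel:
    def __init__(self):
        self.registry = ToyRegistry()

    @contextmanager
    def cache_context(self, name):
        self.registry.set_context(name)
        try:
            yield
        finally:
            # Reset even if the forward pass raises (sketch-only choice).
            self.registry.set_context(None)

    def __call__(self, x):
        state = self.registry.state()
        state["calls"] = state.get("calls", 0) + 1
        return x


model = ToyModel()
with model.cache_context("cond"):
    model(1.0)
    model(1.0)
with model.cache_context("uncond"):
    model(1.0)
```

The pipeline hunks later in this diff follow the same pattern: one `with self.transformer.cache_context(...)` block per forward pass.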
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class CogView4TransformerBlock(nn.Module):
def __init__(
self,
@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -377,7 +378,7 @@ class CosmosLearnablePositionalEmbed(nn.Module):
return (emb / norm).type_as(hidden_states)
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
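Reviewer sketch (not part of the commit): after this change the single block takes text and image tokens separately, concatenates them along the sequence axis, and splits the result again before returning. The toy below shows only that concatenate/split bookkeeping. Plain Python lists stand in for the sequence dimension of the tensors, and `single_block` and `transform` are hypothetical names.

```python
def single_block(hidden_states, encoder_hidden_states, transform):
    """Concatenate text + image tokens, process jointly, split back by text length."""
    text_seq_len = len(encoder_hidden_states)
    # Text tokens come first, matching torch.cat([encoder_hidden_states, hidden_states], dim=1).
    joint = list(encoder_hidden_states) + list(hidden_states)
    joint = [transform(token) for token in joint]
    # Return order mirrors the block in the diff: (encoder_hidden_states, hidden_states).
    return joint[:text_seq_len], joint[text_seq_len:]


encoder_out, hidden_out = single_block(
    hidden_states=[10, 20, 30],
    encoder_hidden_states=[1, 2],
    transform=lambda token: token * 2,
)
```

Returning both halves lets the caller add ControlNet residuals to the image tokens directly instead of slicing them out of a joint tensor.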
@@ -507,20 +512,21 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -530,12 +536,7 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
@@ -22,6 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -71,14 +72,22 @@ class WanAttnProcessor2_0:
if rotary_emb is not None:
def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)
def apply_rotary_emb(
hidden_states: torch.Tensor,
freqs_cos: torch.Tensor,
freqs_sin: torch.Tensor,
):
x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
x1, x2 = x[..., 0], x[..., 1]
cos = freqs_cos[..., 0::2]
sin = freqs_sin[..., 1::2]
out = torch.empty_like(hidden_states)
out[..., 0::2] = x1 * cos - x2 * sin
out[..., 1::2] = x1 * sin + x2 * cos
return out.type_as(hidden_states)
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)
query = apply_rotary_emb(query, *rotary_emb)
key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
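Reviewer sketch (not part of the commit): the new `apply_rotary_emb` replaces a complex multiplication with explicit cos/sin arithmetic on channel pairs. For one pair `(x1, x2)` and angle `theta`, `(x1 + i*x2) * exp(i*theta)` has real part `x1*cos - x2*sin` and imaginary part `x1*sin + x2*cos`, which are the two expressions written into `out[..., 0::2]` and `out[..., 1::2]`. The pure-Python check below confirms that identity numerically; the function names are hypothetical and the tensor layout (the `0::2` / `1::2` slicing of the frequency buffers) is not modelled.

```python
import cmath
import math


def rotate_complex(pairs, angles):
    """Rotate (x1, x2) pairs by multiplying with a unit complex number."""
    out = []
    for (x1, x2), theta in zip(pairs, angles):
        z = complex(x1, x2) * cmath.exp(1j * theta)
        out.append((z.real, z.imag))
    return out


def rotate_real(pairs, angles):
    """Same rotation written with real cos/sin terms, as in the new code path."""
    out = []
    for (x1, x2), theta in zip(pairs, angles):
        cos, sin = math.cos(theta), math.sin(theta)
        out.append((x1 * cos - x2 * sin, x1 * sin + x2 * cos))
    return out


pairs = [(1.0, 0.0), (0.5, -2.0), (-3.0, 4.0)]
angles = [0.3, 1.7, -2.2]
complex_result = rotate_complex(pairs, angles)
real_result = rotate_real(pairs, angles)
```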
@@ -179,7 +188,11 @@ class WanTimeTextImageEmbedding(nn.Module):
class WanRotaryPosEmbed(nn.Module):
def __init__(
self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
self,
attention_head_dim: int,
patch_size: Tuple[int, int, int],
max_seq_len: int,
theta: float = 10000.0,
):
super().__init__()
@@ -189,38 +202,55 @@ class WanRotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
freqs = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
freqs_cos = []
freqs_sin = []
for dim in [t_dim, h_dim, w_dim]:
freq = get_1d_rotary_pos_embed(
dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
freq_cos, freq_sin = get_1d_rotary_pos_embed(
dim,
max_seq_len,
theta,
use_real=True,
repeat_interleave_real=True,
freqs_dtype=freqs_dtype,
)
freqs.append(freq)
self.freqs = torch.cat(freqs, dim=1)
freqs_cos.append(freq_cos)
freqs_sin.append(freq_sin)
self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
freqs = self.freqs.to(hidden_states.device)
freqs = freqs.split_with_sizes(
[
self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
self.attention_head_dim // 6,
self.attention_head_dim // 6,
],
dim=1,
)
split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
return freqs
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
def __init__(
self,
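Reviewer sketch (not part of the commit): the split sizes in `WanRotaryPosEmbed.forward` change from `// 6` to `// 3`, presumably because the real-valued cos/sin buffers are as wide as the per-axis dimensions computed in `__init__` rather than half as wide like the old complex buffer. The check below compares the two formulas. Function names are hypothetical, and 128 is an assumed example value for `attention_head_dim`.

```python
def constructor_dims(attention_head_dim):
    """Per-axis widths (t, h, w) as computed in __init__."""
    h_dim = w_dim = 2 * (attention_head_dim // 6)
    t_dim = attention_head_dim - h_dim - w_dim
    return [t_dim, h_dim, w_dim]


def forward_split_sizes(attention_head_dim):
    """Split sizes (t, h, w) as computed in the new forward()."""
    return [
        attention_head_dim - 2 * (attention_head_dim // 3),
        attention_head_dim // 3,
        attention_head_dim // 3,
    ]


dims_128 = constructor_dims(128)
splits_128 = forward_split_sizes(128)
```

The two formulas agree whenever `attention_head_dim % 6` is 0, 1, or 2 (for 128 the remainder is 2). For other remainders, such as 64, they differ, so a configuration with such a head dimension would be worth testing separately.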
+2
@@ -141,6 +141,7 @@ else:
"FluxPriorReduxPipeline",
"ReduxImageEncoder",
"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -610,6 +611,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxKontextInpaintPipeline,
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
ofs=ofs_emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
ofs=ofs_emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -643,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
@@ -29,7 +29,7 @@ from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
@@ -88,7 +88,7 @@ EXAMPLE_DOC_STRING = """
"""
-class BlipDiffusionControlNetPipeline(DiffusionPipeline):
+class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
     """
     Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
@@ -116,6 +116,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """

+    _last_supported_version = "0.33.1"
     model_cpu_offload_seq = "qformer->text_encoder->unet->vae"

     def __init__(
@@ -34,6 +34,7 @@ else:
     _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
     _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
     _import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"]
+    _import_structure["pipeline_flux_kontext_inpaint"] = ["FluxKontextInpaintPipeline"]
     _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -54,6 +55,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .pipeline_flux_img2img import FluxImg2ImgPipeline
         from .pipeline_flux_inpaint import FluxInpaintPipeline
         from .pipeline_flux_kontext import FluxKontextPipeline
+        from .pipeline_flux_kontext_inpaint import FluxKontextInpaintPipeline
         from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
 else:
     import sys
@@ -912,32 +912,35 @@ class FluxPipeline(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
                         guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
                         img_ids=latent_image_ids,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=negative_text_ids,
+                            img_ids=latent_image_ids,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
@@ -163,9 +163,9 @@ class FluxControlPipeline(
     TextualInversionLoaderMixin,
 ):
     r"""
-    The Flux pipeline for controllable text-to-image generation.
+    The Flux pipeline for controllable text-to-image generation with image conditions.

-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+    Reference: https://bfl.ai/flux-1-tools

     Args:
         transformer ([`FluxTransformer2DModel`]):
@@ -195,9 +195,9 @@ class FluxKontextPipeline(
     FluxIPAdapterMixin,
 ):
     r"""
-    The Flux Kontext pipeline for text-to-image generation.
+    The Flux Kontext pipeline for image-to-image and text-to-image generation.

-    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+    Reference: https://bfl.ai/announcements/flux-1-kontext-dev

     Args:
         transformer ([`FluxTransformer2DModel`]):
File diff suppressed because it is too large
@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    pooled_projections=pooled_prompt_embeds,
-                    guidance=guidance,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        encoder_attention_mask=negative_prompt_attention_mask,
-                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_attention_mask=prompt_attention_mask,
+                        pooled_projections=pooled_prompt_embeds,
                         guidance=guidance,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_attention_mask=negative_prompt_attention_mask,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            guidance=guidance,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
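
Note on the pattern in the hunks above: every transformer call is now wrapped in a named `cache_context` ("cond" or "uncond"), presumably so that hook-based caches such as First Block Cache (#11180) keep separate state for the conditional and unconditional passes. The sketch below is a toy illustration under that assumption and is not the diffusers implementation. `ToyCachedModel`, its `_states` dict, `guided_step`, and the stand-in arithmetic are hypothetical names invented for this note. It shows a context manager selecting a per-name cache slot, followed by the same guidance combination used in the diff.

```python
from contextlib import contextmanager


class ToyCachedModel:
    """Hypothetical stand-in for a transformer with per-context cache state."""

    def __init__(self):
        self._states = {}     # context name -> last output recorded in that context
        self._current = None  # name of the active context, if any

    @contextmanager
    def cache_context(self, name):
        # Select which cache slot calls made inside the `with` block will use,
        # and restore the previous selection on exit.
        previous, self._current = self._current, name
        try:
            yield
        finally:
            self._current = previous

    def __call__(self, x, embeds):
        out = x * embeds  # stand-in for the real forward pass
        # Record the output under the active context only.
        self._states[self._current] = out
        return out


def guided_step(model, latents, prompt, negative_prompt, guidance_scale):
    with model.cache_context("cond"):
        noise_pred_cond = model(latents, prompt)
    with model.cache_context("uncond"):
        noise_pred_uncond = model(latents, negative_prompt)
    # Same combination as in the diff: uncond + scale * (cond - uncond)
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)


model = ToyCachedModel()
noise_pred = guided_step(model, latents=2.0, prompt=3.0, negative_prompt=1.0, guidance_scale=5.0)
print(noise_pred)
print(model._states)
```

Because each pass writes to its own slot, the "uncond" call cannot overwrite what the "cond" call stored. Without separate names, a cache that compares the current input to the previous one would be comparing across two different prompts.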

Some files were not shown because too many files have changed in this diff