Compare commits

...

36 Commits

Author SHA1 Message Date
Sayak Paul 3fb66f23ac Merge branch 'main' into migrate-lora-pytest 2025-11-20 10:13:01 +05:30
sayakpaul 9c3bed1783 up 2025-11-20 10:12:31 +05:30
Pratim Dasude d5da453de5 Community Pipeline: FluxFillControlNetInpaintPipeline for FLUX Fill-Based Inpainting with ControlNet (#12649)
* new flux fill controlnet inpaint pipline

* Delete src/diffusers/pipelines/flux/pipline_flux_fill_controlnet_Inpaint.py

deleting from main flux pipeline

* Fluc_fill_controlnet community pipline

* Update README.md

* Apply style fixes
2025-11-19 16:18:46 -03:00
David El Malih 15370f8412 Improve docstrings and type hints in scheduling_pndm.py (#12676)
* Enhance docstrings and type hints in PNDMScheduler class

- Updated parameter descriptions to include default values and specific types using Literal for better clarity.
- Improved docstring formatting and consistency across methods, including detailed explanations for the `_get_prev_sample` method.
- Added type hints for method return types to enhance code readability and maintainability.

* Refactor docstring in PNDMScheduler class to enhance clarity

- Simplified the explanation of the method for computing the previous sample from the current sample.
- Updated the reference to the PNDM paper for better accessibility.
- Removed redundant notation explanations to streamline the documentation.
2025-11-19 09:36:41 -08:00
Dhruv Nair a96b145304 [CI] Fix failing Pipeline CPU tests (#12681)
update

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-11-19 21:19:24 +05:30
Dhruv Nair 6d8973ffe2 [CI] Fix indentation issue in workflow files (#12685)
update
2025-11-19 09:30:04 +05:30
Sayak Paul ab71f3c864 [core] Refactor hub attn kernels (#12475)
* refactor how attention kernels from hub are used.

* up

* refactor according to Dhruv's ideas.

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>

* empty

Co-authored-by: dn6 <dhruv@huggingface.co>

* up

---------

Co-authored-by: Dhruv Nair <dhruv@huggingface.co>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2025-11-19 08:19:00 +05:30
Dhruv Nair b7df4a5387 [CI] Temporarily pin transformers (#12677)
* update

* update

* update

* update
2025-11-18 14:43:06 +05:30
dg845 67dc65e2e3 Revert AutoencoderKLWan's dim_mult default value back to list (#12640)
Revert dim_mult back to list and fix type annotation
2025-11-17 18:39:53 +05:30
Dhruv Nair 3579fdabf9 [CI] Make CI logs less verbose (#12674)
update
2025-11-17 14:23:09 +05:30
Junsong Chen 1afc21855e SANA-Video Image to Video pipeline SanaImageToVideoPipeline support (#12634)
* move sana-video to a new dir and add `SanaImageToVideoPipeline` with no modify;

* fix bug and run text/image-to-vidoe success;

* make style; quality; fix-copies;

* add sana image-to-video pipeline in markdown;

* add test case for sana image-to-video;

* make style;

* add a init file in sana-video test dir;

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/pipelines/sana_video/test_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* minor update;

* fix bug and skip fp16 save test;

Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* add copied from for `encode_prompt`

* Apply style fixes

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: Yuyang Zhao <43061147+HeliosZhao@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-17 00:23:34 -08:00
Sayak Paul 11b80d09b0 Merge branch 'main' into migrate-lora-pytest 2025-11-10 13:27:10 +05:30
Sayak Paul 9201505554 Merge branch 'main' into migrate-lora-pytest 2025-11-06 10:39:44 +05:30
sayakpaul eece7120dd up 2025-11-06 10:31:37 +05:30
Sayak Paul 2e42205c3a Merge branch 'main' into migrate-lora-pytest 2025-11-06 10:24:51 +05:30
Sayak Paul 757bbf7b05 Merge branch 'main' into migrate-lora-pytest 2025-10-24 22:24:15 +05:30
Sayak Paul 4561c065aa Merge branch 'main' into migrate-lora-pytest 2025-10-17 19:29:40 +05:30
Sayak Paul 4ae5772fef Merge branch 'main' into migrate-lora-pytest 2025-10-17 07:55:31 +05:30
sayakpaul 0d3da485a0 up 2025-10-03 21:00:05 +05:30
sayakpaul 4f5e9a665e up 2025-10-03 20:49:50 +05:30
Sayak Paul 23e5559c54 Merge branch 'main' into migrate-lora-pytest 2025-10-03 20:44:52 +05:30
sayakpaul f8f27891c6 up 2025-10-03 20:14:45 +05:30
sayakpaul 128535cfcd up 2025-10-03 20:03:50 +05:30
sayakpaul bdc9537999 more fixtures. 2025-10-03 20:01:26 +05:30
sayakpaul dae161ed26 up 2025-10-03 17:39:55 +05:30
sayakpaul c4bcf72084 up 2025-10-03 16:56:31 +05:30
sayakpaul 1737b710a2 up 2025-10-03 16:45:04 +05:30
sayakpaul 565d674cc4 change flux lora integration tests to use pytest 2025-10-03 16:30:58 +05:30
sayakpaul 610842af1a up 2025-10-03 16:14:36 +05:30
sayakpaul cba82591e8 up 2025-10-03 15:56:37 +05:30
sayakpaul 949cc1c326 up 2025-10-03 14:54:23 +05:30
sayakpaul ec866f5de8 tempfile is now a fixture. 2025-10-03 14:25:54 +05:30
sayakpaul 7b4bcce602 up 2025-10-03 14:10:31 +05:30
sayakpaul d61bb38fb4 up 2025-10-03 13:14:05 +05:30
sayakpaul 9e92f6bb63 up 2025-10-03 12:53:37 +05:30
sayakpaul 6c6cade1a7 migrate lora pipeline tests to pytest 2025-10-03 12:52:56 +05:30
52 changed files with 4215 additions and 1918 deletions
+21 -7
View File
@@ -73,6 +73,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -84,7 +86,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
tests/pipelines/${{ matrix.module }}
@@ -126,6 +128,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: python utils/print_env.py
@@ -138,7 +142,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
tests/${{ matrix.module }}
@@ -151,7 +155,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v --make-reports=examples_torch_cuda \
--make-reports=examples_torch_cuda \
--report-log=examples_torch_cuda.log \
examples/
@@ -190,6 +194,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -198,7 +204,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -232,6 +238,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -281,6 +289,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -293,7 +303,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_version_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -358,6 +368,8 @@ jobs:
uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
fi
uv pip install pytest-reportlog
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -405,6 +417,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip install -U bitsandbytes optimum_quanto
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install pytest-reportlog
- name: Environment
run: |
@@ -531,7 +545,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
@@ -587,7 +601,7 @@ jobs:
# HF_HOME: /System/Volumes/Data/mnt/cache
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# run: |
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
# --report-log=tests_torch_mps.log \
# tests/
# - name: Failure short reports
+3 -2
View File
@@ -109,7 +109,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -120,7 +121,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
+8 -6
View File
@@ -115,7 +115,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -126,7 +127,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/pipelines
@@ -134,7 +135,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch_models' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and not Dependency" \
-k "not Flax and not Onnx and not Dependency" \
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
@@ -246,7 +247,8 @@ jobs:
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
uv pip install -U tokenizers
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -255,11 +257,11 @@ jobs:
- name: Run fast PyTorch LoRA tests with PEFT
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
\
--make-reports=tests_peft_main \
tests/lora/
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
\
--make-reports=tests_models_lora_peft_main \
tests/models/ -k "lora"
+18 -15
View File
@@ -1,4 +1,4 @@
name: Fast GPU Tests on PR
name: Fast GPU Tests on PR
on:
pull_request:
@@ -71,7 +71,7 @@ jobs:
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
setup_torch_cuda_pipeline_matrix:
needs: [check_code_quality, check_repository_consistency]
name: Setup Torch Pipelines CUDA Slow Tests Matrix
@@ -131,7 +131,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -149,18 +150,18 @@ jobs:
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
else
else
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx and $pattern" \
-k "not Flax and not Onnx and $pattern" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
fi
fi
- name: Failure short reports
if: ${{ failure() }}
@@ -201,7 +202,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -222,11 +224,11 @@ jobs:
run: |
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
if [ -z "$pattern" ]; then
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
else
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
--make-reports=tests_torch_cuda_${{ matrix.module }}
fi
- name: Failure short reports
@@ -262,7 +264,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install -e ".[quality,training]"
- name: Environment
@@ -274,7 +277,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+11 -8
View File
@@ -76,7 +76,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -87,7 +88,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -128,7 +129,8 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -141,7 +143,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_cuda_${{ matrix.module }} \
tests/${{ matrix.module }}
@@ -180,7 +182,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -189,7 +192,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -230,7 +233,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -273,7 +276,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+1 -1
View File
@@ -70,7 +70,7 @@ jobs:
if: ${{ matrix.config.framework == 'pytorch' }}
run: |
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
tests/
+1 -1
View File
@@ -57,7 +57,7 @@ jobs:
HF_HOME: /System/Volumes/Data/mnt/cache
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/
- name: Failure short reports
if: ${{ failure() }}
+6 -6
View File
@@ -84,7 +84,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
@@ -137,7 +137,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_${{ matrix.module }}_cuda \
tests/${{ matrix.module }}
@@ -187,7 +187,7 @@ jobs:
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_minimum_cuda \
tests/models/test_modeling_common.py \
tests/pipelines/test_pipelines_common.py \
@@ -240,7 +240,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
RUN_COMPILE: yes
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
@@ -281,7 +281,7 @@ jobs:
env:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
@@ -326,7 +326,7 @@ jobs:
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
uv pip install ".[training]"
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
+88 -2
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->
# SanaVideoPipeline
# Sana-Video
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
@@ -37,6 +37,85 @@ Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-vi
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
## Generation Pipelines
<hfoptions id="generation pipelines">`
<hfoption id="Text-to-Video">
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text descriptio and a starting frame.
```python
model_id =
pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", torch_dtype=torch.bfloat16)
pipe.text_encoder.to(torch.bfloat16)
pipe.vae.to(torch.float32)
pipe.to("cuda")
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
video = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana_video.mp4", fps=16)
```
</hfoption>
<hfoption id="Image-to-Video">
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text descriptio and a starting frame.
```python
model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
pipe = SanaImageToVideoPipeline.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
)
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
pipe.vae.to(torch.float32)
pipe.text_encoder.to(torch.bfloat16)
pipe.to("cuda")
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
motion_scale = 30
motion_prompt = f" motion score: {motion_scale}."
prompt = prompt + motion_prompt
motion_scale = 30.0
video = pipe(
image=image,
prompt=prompt,
negative_prompt=negative_prompt,
height=480,
width=832,
frames=81,
guidance_scale=6,
num_inference_steps=50,
generator=torch.Generator(device="cuda").manual_seed(0),
).frames[0]
export_to_video(video, "sana-i2v.mp4", fps=16)
```
</hfoption>
</hfoptions>
## Quantization
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
@@ -97,6 +176,13 @@ export_to_video(output, "sana-video-output.mp4", fps=16)
- __call__
## SanaImageToVideoPipeline
[[autodoc]] SanaImageToVideoPipeline
- all
- __call__
## SanaVideoPipelineOutput
[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput
+104 -1
View File
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [![Hugging Face Models](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports Controlnet with Flux Fill model.| [Flux Fill ControlNet Pipeline](#Flux-Fill-ControlNet-Pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
@@ -5527,3 +5527,106 @@ images = pipe(
).images
images[0].save("pizzeria.png")
```
# Flux Fill ControlNet Pipeline
This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.
While FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.
## Example Usage
```python
import torch
from diffusers import (
FluxControlNetModel,
FluxPriorReduxPipeline,
)
from diffusers.utils import load_image
# NEW PIPELINE (updated name)
from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
# Models
base_model = "black-forest-labs/FLUX.1-Fill-dev"
controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0"
prior_model = "black-forest-labs/FLUX.1-Redux-dev"
# Load ControlNet
controlnet = FluxControlNetModel.from_pretrained(
controlnet_model,
torch_dtype=dtype,
)
# Load Fill + ControlNet Pipeline
fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
base_model,
controlnet=controlnet,
torch_dtype=dtype,
).to(device)
# OPTIONAL FP8
# fill_pipe.transformer.enable_layerwise_casting(
# storage_dtype=torch.float8_e4m3fn,
# compute_dtype=torch.bfloat16
# )
# OPTIONAL Prior Redux
#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
# prior_model,
# torch_dtype=dtype,
#).to(device)
# Inputs
# combined_image = load_image("person_input.png")
# 1. Prior conditioning
#prior_out = pipe_prior_redux(
# image=cloth_image,
# prompt=cloth_prompt,
#)
# 2. Fill Inpaint with ControlNet
# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
img = load_image(r"imgs/background.jpg")
mask = load_image(r"imgs/mask.png")
control_image_depth = load_image(r"imgs/dog_depth _2.png")
result = fill_pipe(
prompt="a dog on a bench",
image=img,
mask_image=mask,
control_image=control_image_depth,
control_mode=[2], # union mode
control_guidance_start=0.0,
control_guidance_end=0.8,
controlnet_conditioning_scale=0.9,
height=1024,
width=1024,
strength=1.0,
guidance_scale=50.0,
num_inference_steps=60,
max_sequence_length=512,
# **prior_out,
)
# result.images[0].save("flux_fill_controlnet_inpaint.png")
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
result.images[0].save(f"flux_fill_controlnet_inpaint_depth{timestamp}.jpg")
```
File diff suppressed because it is too large Load Diff
@@ -80,6 +80,8 @@ def main(args):
# scheduler
flow_shift = 8.0
if args.task == "i2v":
assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
# model config
layer_num = 20
@@ -312,6 +314,7 @@ if __name__ == "__main__":
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
+3
View File
@@ -545,11 +545,13 @@ else:
"QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
"SanaImageToVideoPipeline",
"SanaPAGPipeline",
"SanaPipeline",
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
"SanaVideoPipeline",
"SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -1227,6 +1229,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
SanaImageToVideoPipeline,
SanaPAGPipeline,
SanaPipeline,
SanaSprintImg2ImgPipeline,
+47 -20
View File
@@ -16,6 +16,7 @@ import contextlib
import functools
import inspect
import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@@ -42,7 +43,7 @@ from ..utils import (
is_xformers_available,
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
if TYPE_CHECKING:
@@ -82,24 +83,11 @@ else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None
if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub
flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
else:
flash_attn_3_func_hub = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
@@ -261,6 +249,25 @@ class _AttentionBackendRegistry:
return supports_context_parallel
@dataclass
class _HubKernelConfig:
"""Configuration for downloading and using a hub-based attention kernel."""
repo_id: str
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
)
}
@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
@@ -415,13 +422,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend == AttentionBackendName.AITER:
@@ -571,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return q_idx >= kv_idx
# ===== Helpers for downloading kernels =====
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
raise
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
@@ -1418,7 +1444,8 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
out = flash_attn_3_func_hub(
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
@@ -971,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
base_dim: int = 96,
decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
dim_mult: List[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
attn_scales: List[float] = [],
temperal_downsample: List[bool] = [False, True, True],
+7 -1
View File
@@ -595,7 +595,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
from .attention_dispatch import (
AttentionBackendName,
_check_attention_backend_requirements,
_maybe_download_kernel_for_backend,
)
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
@@ -606,8 +610,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
if backend not in available_backends:
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
_maybe_download_kernel_for_backend(backend)
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
@@ -237,7 +237,6 @@ class WanRotaryPosEmbed(nn.Module):
return freqs_cos, freqs_sin
# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
class SanaModulatedNorm(nn.Module):
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
super().__init__()
@@ -247,7 +246,7 @@ class SanaModulatedNorm(nn.Module):
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
) -> torch.Tensor:
hidden_states = self.norm(hidden_states)
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
shift, scale = (scale_shift_table[None, None] + temb[:, :, None].to(scale_shift_table.device)).unbind(dim=2)
hidden_states = hidden_states * (1 + scale) + shift
return hidden_states
@@ -423,8 +422,8 @@ class SanaVideoTransformerBlock(nn.Module):
# 1. Modulation
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1)
).unbind(dim=2)
# 2. Self Attention
norm_hidden_states = self.norm1(hidden_states)
@@ -635,13 +634,16 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
if guidance is not None:
timestep, embedded_timestep = self.time_embed(
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype
)
else:
timestep, embedded_timestep = self.time_embed(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
timestep = timestep.view(batch_size, -1, timestep.size(-1))
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+4 -1
View File
@@ -308,7 +308,10 @@ else:
"SanaSprintPipeline",
"SanaControlNetPipeline",
"SanaSprintImg2ImgPipeline",
]
_import_structure["sana_video"] = [
"SanaVideoPipeline",
"SanaImageToVideoPipeline",
]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -749,8 +752,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SanaPipeline,
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
SanaVideoPipeline,
)
from .sana_video import SanaImageToVideoPipeline, SanaVideoPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
-2
View File
@@ -26,7 +26,6 @@ else:
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
_import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -40,7 +39,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_sana_controlnet import SanaControlNetPipeline
from .pipeline_sana_sprint import SanaSprintPipeline
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
from .pipeline_sana_video import SanaVideoPipeline
else:
import sys
@@ -3,7 +3,6 @@ from typing import List, Union
import numpy as np
import PIL.Image
import torch
from ...utils import BaseOutput
@@ -20,18 +19,3 @@ class SanaPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
@dataclass
class SanaVideoPipelineOutput(BaseOutput):
r"""
Output class for Sana-Video pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
@@ -0,0 +1,49 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
_import_structure["pipeline_sana_video_i2v"] = ["SanaImageToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_sana_video import SanaVideoPipeline
from .pipeline_sana_video_i2v import SanaImageToVideoPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from ...utils import BaseOutput
@dataclass
class SanaVideoPipelineOutput(BaseOutput):
r"""
Output class for Sana-Video pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
@@ -95,17 +95,16 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import SanaVideoPipeline
>>> from diffusers.utils import export_to_video
>>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
>>> pipe = SanaVideoPipeline.from_pretrained(model_id)
>>> pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers")
>>> pipe.transformer.to(torch.bfloat16)
>>> pipe.text_encoder.to(torch.bfloat16)
>>> pipe.vae.to(torch.float32)
>>> pipe.to("cuda")
>>> model_score = 30
>>> motion_score = 30
>>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
>>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
>>> motion_prompt = f" motion score: {model_score}."
>>> motion_prompt = f" motion score: {motion_score}."
>>> prompt = prompt + motion_prompt
>>> output = pipe(
@@ -231,6 +230,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
@@ -827,9 +827,9 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Examples:
Returns:
[`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated videos
[`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.sana_video.pipeline_output.SanaVideoPipelineOutput`] is
returned, otherwise a `tuple` is returned where the first element is a list with the generated videos
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
File diff suppressed because it is too large Load Diff
+37 -33
View File
@@ -79,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
beta_start (`float`, defaults to `0.0001`):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
beta_end (`float`, defaults to `0.02`):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
trained_betas (`np.ndarray`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
skip_prk_steps (`bool`, defaults to `False`):
@@ -97,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
otherwise it uses the alpha value at step 0.
prediction_type (`str`, defaults to `epsilon`, *optional*):
prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
paper).
timestep_spacing (`str`, defaults to `"leading"`):
or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, defaults to 0):
steps_offset (`int`, defaults to `0`):
An offset added to the inference steps, as required by some model families.
"""
@@ -117,12 +115,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
prediction_type: str = "epsilon",
timestep_spacing: str = "leading",
prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
steps_offset: int = 0,
):
if trained_betas is not None:
@@ -164,7 +162,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.plms_timesteps = None
self.timesteps = None
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -243,7 +241,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
Returns:
@@ -276,14 +274,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -335,14 +332,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
if self.num_inference_steps is None:
raise ValueError(
@@ -403,19 +399,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
# See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
# this function computes x_(t−δ) using the formula of (9)
# Note that x_t needs to be added to both sides of the equation
def _get_prev_sample(
self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
) -> torch.Tensor:
"""
Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM
paper](https://huggingface.co/papers/2202.09778).
# Notation (<variable name> -> <name in paper>
# alpha_prod_t -> α_t
# alpha_prod_t_prev -> α_(t−δ)
# beta_prod_t -> (1 - α_t)
# beta_prod_t_prev -> (1 - α_(t−δ))
# sample -> x_t
# model_output -> e_θ(x_t, t)
# prev_sample -> x_(t−δ)
Args:
sample (`torch.Tensor`):
The current sample x_t.
timestep (`int`):
The current timestep t.
prev_timestep (`int`):
The previous timestep (t-δ).
model_output (`torch.Tensor`):
The model output e_θ(x_t, t).
Returns:
`torch.Tensor`:
The previous sample x_(t-δ).
"""
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -489,5 +493,5 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
-1
View File
@@ -46,7 +46,6 @@ DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_V
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
@@ -2147,6 +2147,21 @@ class SanaControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class SanaImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class SanaPAGPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
-23
View File
@@ -1,23 +0,0 @@
from ..utils import get_logger
from .import_utils import is_kernels_available
logger = get_logger(__name__)
_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
def _get_fa3_from_hub():
if not is_kernels_available():
return None
else:
from kernels import get_kernel
try:
# TODO: temporary revision for now. Remove when merged upstream into `main`.
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
return flash_attn_3_hub
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
raise
+11 -15
View File
@@ -13,16 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler
from ..testing_utils import (
floats_tensor,
@@ -40,7 +36,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestAuraFlowLoRA(PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -103,34 +99,34 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in AuraFlow.")
@pytest.mark.skip("Not supported in AuraFlow.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in AuraFlow.")
@pytest.mark.skip("Not supported in AuraFlow.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in AuraFlow.")
@pytest.mark.skip("Not supported in AuraFlow.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+25 -21
View File
@@ -13,10 +13,9 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
@@ -39,7 +38,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestCogVideoXLoRA(PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"}
@@ -119,54 +118,59 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3, pipe=pipe)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
super().test_lora_scale_kwargs_match_fusion(
base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3
)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@pytest.mark.parametrize(
"offload_type, use_stream",
[("block_level", True), ("leaf_level", False)],
)
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
+20 -46
View File
@@ -13,12 +13,9 @@
# limitations under the License.
import sys
import tempfile
import unittest
import numpy as np
import pytest
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, GlmModel
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
@@ -28,7 +25,6 @@ from ..testing_utils import (
require_peft_backend,
require_torch_accelerator,
skip_mps,
torch_device,
)
@@ -47,7 +43,7 @@ class TokenizerWrapper:
@require_peft_backend
@skip_mps
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestCogView4LoRA(PeftLoraLoaderMixinTests):
pipeline_class = CogView4Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -113,72 +109,50 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_save_pretrained(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@pytest.mark.parametrize(
"offload_type, use_stream",
[("block_level", True), ("leaf_level", False)],
)
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
@unittest.skip("Not supported in CogView4.")
@pytest.mark.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in CogView4.")
@pytest.mark.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in CogView4.")
@pytest.mark.skip("Not supported in CogView4.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+240 -340
View File
@@ -16,13 +16,11 @@ import copy
import gc
import os
import sys
import tempfile
import unittest
import numpy as np
import pytest
import safetensors.torch
import torch
from parameterized import parameterized
from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -46,14 +44,12 @@ from ..testing_utils import (
if is_peft_available():
from peft.utils import get_peft_model_state_dict
sys.path.append(".")
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set
@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestFluxLoRA(PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -115,165 +111,134 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_with_alpha_in_state_dict(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_with_alpha_in_state_dict(self, tmpdirname, pipe):
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# modify the state dict to have alpha values following
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
state_dict_with_alpha = safetensors.torch.load_file(
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
)
alpha_dict = {}
for k, v in state_dict_with_alpha.items():
# only do for `transformer` and for the k projections -- should be enough to test.
if "transformer" in k and "to_k" in k and "lora_A" in k:
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
state_dict_with_alpha.update(alpha_dict)
# modify the state dict to have alpha values following
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
state_dict_with_alpha = safetensors.torch.load_file(
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
)
alpha_dict = {}
for k, v in state_dict_with_alpha.items():
if "transformer" in k and "to_k" in k and ("lora_A" in k):
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
state_dict_with_alpha.update(alpha_dict)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
pipe.unload_lora_weights()
pipe.load_lora_weights(state_dict_with_alpha)
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
"Loading from saved checkpoints should give same results."
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)
def test_lora_expansion_works_for_absent_keys(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname, pipe):
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder")
pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
"LoRA should lead to different results."
)
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
assert not np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
"Different LoRAs should lead to different results."
)
self.assertFalse(
np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
assert not np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
"LoRA should lead to different results."
)
def test_lora_expansion_works_for_extra_keys(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname, pipe):
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder")
pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
"LoRA should lead to different results."
)
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
# Load state dict with `x_embedder`.
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
pipe.unload_lora_weights()
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
assert not np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
"Different LoRAs should lead to different results."
)
self.assertFalse(
np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
assert not np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
"LoRA should lead to different results."
)
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -338,12 +303,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_with_norm_in_state_dict(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_with_norm_in_state_dict(self, pipe):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
@@ -364,39 +324,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.load_lora_weights(norm_state_dict)
lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
assert (
"The provided state dict contains normalization layers in addition to LoRA layers"
in cap_logger.out
)
self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
assert len(pipe.transformer._transformer_norm_layers) > 0
pipe.unload_lora_weights()
lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(pipe.transformer._transformer_norm_layers is None)
self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5))
self.assertFalse(
np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested"
assert pipe.transformer._transformer_norm_layers is None
assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05)
assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), (
f"{norm_layer} is tested"
)
with CaptureLogger(logger) as cap_logger:
for key in list(norm_state_dict.keys()):
norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
pipe.load_lora_weights(norm_state_dict)
assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
self.assertTrue(
"Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
)
def test_lora_parameter_expanded_shapes(self):
def test_lora_parameter_expanded_shapes(self, pipe):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
@@ -405,24 +358,21 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
assert transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
)
original_transformer_state_dict = pipe.transformer.state_dict()
x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
self.assertTrue(
"x_embedder.weight" in incompatible_keys.missing_keys,
"Could not find x_embedder.weight in the missing keys.",
assert "x_embedder.weight" in incompatible_keys.missing_keys, (
"Could not find x_embedder.weight in the missing keys."
)
transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
pipe.transformer = transformer
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
@@ -431,15 +381,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
# Testing opposite direction where the LoRA params are zero-padded.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -454,15 +402,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
def test_normal_lora_with_expanded_lora_raises_error(self):
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
@@ -494,32 +440,28 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
_, _, inputs = self.get_dummy_inputs(with_generator=False)
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
assert pipe.get_active_adapters() == ["adapter-1"]
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
(_, _, inputs) = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
assert pipe.get_active_adapters() == ["adapter-2"]
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible.
@@ -540,32 +482,24 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
assert pipe.transformer.config.in_channels == in_features
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
# We should check for input shapes being incompatible here. But because above mentioned issue is
# not a supported use case, and because of the PEFT renaming, we will currently have a shape
# mismatch error.
self.assertRaisesRegex(
RuntimeError,
"size mismatch for x_embedder.lora_A.adapter-2.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)
with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"):
pipe.load_lora_weights(lora_state_dict, "adapter-2")
def test_fuse_expanded_lora_with_regular_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
@@ -597,7 +531,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
_, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -610,54 +544,44 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001)
assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001)
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001)
def test_load_regular_lora(self):
def test_load_regular_lora(self, base_pipe_output, pipe):
# This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
# transformers include Flux Fill, Flux Control, etc.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
in_features = in_features // 2
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001)
def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -670,9 +594,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
assert transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
)
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
@@ -697,33 +620,31 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
self.assertTrue(
control_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
assert control_pipe.transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
)
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
self.assertTrue(
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}"
)
inputs.pop("control_image")
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001)
assert np.allclose(unloaded_lora_out, original_out, atol=0.0001, rtol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
assert pipe.transformer.config.in_channels == in_features
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -731,14 +652,12 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
assert transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
)
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
@@ -763,40 +682,38 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
self.assertTrue(
control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
)
no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
assert pipe.transformer.config.in_channels == in_features * 2
self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
@@ -806,7 +723,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
class TestFluxLoRAIntegration:
"""internal note: The integration slices were obtained on audace.
torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
@@ -816,33 +733,27 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
def setUp(self):
super().setUp()
@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
def tearDown(self):
super().tearDown()
del self.pipeline
gc.collect()
backend_empty_cache(torch_device)
def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
# Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
self.pipeline = self.pipeline.to(torch_device)
def test_flux_the_last_ben(self, pipeline):
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "jon snow eating pizza with ketchup"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.0,
@@ -851,71 +762,57 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
def test_flux_kohya(self):
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_kohya(self, pipeline):
pipeline.load_lora_weights("Norod78/brain-slug-flux")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "The cat with a brain slug earring"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
def test_flux_kohya_with_text_encoder(self):
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_kohya_with_text_encoder(self, pipeline):
pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "optimus is cleaning the house with broomstick"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
def test_flux_kohya_embedders_conversion(self):
def test_flux_kohya_embedders_conversion(self, pipeline):
"""Test that embedders load without throwing errors"""
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
self.pipeline.unload_lora_weights()
assert True
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
pipeline.unload_lora_weights()
def test_flux_xlabs(self, pipeline):
pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
@@ -923,23 +820,17 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
expected_slice = np.array([0.3965, 0.418, 0.4434, 0.4082, 0.4375, 0.459, 0.4141, 0.4375, 0.498])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
def test_flux_xlabs_load_lora_with_single_blocks(self):
self.pipeline.load_lora_weights(
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()
def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline):
pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.enable_model_cpu_offload()
prompt = "a wizard mouse playing chess"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
@@ -951,40 +842,43 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
[0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
assert max_diff < 0.001
@nightly
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
class TestFluxControlLoRAIntegration:
num_inference_steps = 10
seed = 0
prompt = "A robot made of exotic candies and chocolates of different kinds."
def setUp(self):
super().setUp()
@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
self.pipeline = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
@pytest.mark.parametrize(
"lora_ckpt_id",
[
"black-forest-labs/FLUX.1-Canny-dev-lora",
"black-forest-labs/FLUX.1-Depth-dev-lora",
],
)
def test_lora(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.fuse_lora()
pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id:
control_image = load_image(
@@ -995,7 +889,7 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
)
image = self.pipeline(
image = pipeline(
prompt=self.prompt,
control_image=control_image,
height=1024,
@@ -1016,12 +910,18 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
assert max_diff < 1e-3
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora_with_turbo(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
@pytest.mark.parametrize(
"lora_ckpt_id",
[
"black-forest-labs/FLUX.1-Canny-dev-lora",
"black-forest-labs/FLUX.1-Depth-dev-lora",
],
)
def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id:
control_image = load_image(
+32 -38
View File
@@ -14,9 +14,9 @@
import gc
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -48,7 +48,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -149,46 +149,41 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
# TODO(aryan): Fix the following test
@unittest.skip("This test fails with an error I haven't been able to debug yet.")
def test_simple_inference_save_pretrained(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@@ -197,7 +192,7 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
class TestHunyuanVideoLoRAIntegration:
"""internal note: The integration slices were obtained on DGX.
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
@@ -207,9 +202,8 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
def setUp(self):
super().setUp()
@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -217,27 +211,27 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
self.pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.float16
).to(torch_device)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
def test_original_format_cseti(self, pipeline):
pipeline.load_lora_weights(
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.vae.enable_tiling()
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.vae.enable_tiling()
prompt = "CSETIARCANE. A cat walks on the grass, realistic"
out = self.pipeline(
out = pipeline(
prompt=prompt,
height=320,
width=512,
+14 -14
View File
@@ -13,8 +13,8 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestLTXVideoLoRA(PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -108,40 +108,40 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+14 -18
View File
@@ -13,7 +13,6 @@
# limitations under the License.
import sys
import unittest
import numpy as np
import pytest
@@ -36,7 +35,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestLumina2LoRA(PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -101,35 +100,35 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@@ -139,20 +138,17 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=False,
)
def test_lora_fuse_nan(self):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_lora_fuse_nan(self, pipe):
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
@@ -166,4 +162,4 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
assert np.isnan(out).all()
+15 -15
View File
@@ -13,8 +13,8 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestMochiLoRA(PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -99,44 +99,44 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
+10 -10
View File
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestQwenImageLoRA(PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -96,34 +96,34 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+11 -11
View File
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import Gemma2Model, GemmaTokenizer
@@ -29,7 +29,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestSanaLoRA(PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {"shift": 7.0}
@@ -105,38 +105,38 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser()
+36 -70
View File
@@ -14,9 +14,9 @@
# limitations under the License.
import gc
import sys
import unittest
import numpy as np
import pytest
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
@@ -55,7 +55,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusionPipeline
scheduler_cls = DDIMScheduler
scheduler_kwargs = {
@@ -91,16 +91,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self):
return (1, 64, 64, 3)
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
# Keeping this test here makes sense because it doesn't look any integration
# (value assertions on logits).
@slow
@@ -114,15 +104,8 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.load_lora_weights(lora_id, adapter_name="adapter-2")
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
# We will offload the first adapter in CPU and check if the offloading
# has been performed correctly
@@ -130,35 +113,35 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
for name, module in pipe.unet.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
for name, module in pipe.text_encoder.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
pipe.set_lora_device(["adapter-1"], 0)
for n, m in pipe.unet.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
for n, m in pipe.text_encoder.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device)
for n, m in pipe.unet.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
for n, m in pipe.text_encoder.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
@slow
@require_torch_accelerator
@@ -181,15 +164,9 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
for name, param in pipe.unet.named_parameters():
if "lora_" in name:
@@ -225,17 +202,14 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(config1, adapter_name="adapter-1")
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
# sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
self.assertNotEqual(modules_adapter_0, modules_adapter_1)
self.assertTrue(modules_adapter_0 - modules_adapter_1)
self.assertTrue(modules_adapter_1 - modules_adapter_0)
assert modules_adapter_0 != modules_adapter_1
assert modules_adapter_0 - modules_adapter_1
assert modules_adapter_1 - modules_adapter_0
# setting both separately works
pipe.set_lora_device(["adapter-0"], "cpu")
@@ -243,32 +217,30 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
# setting both at once also works
pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
@slow
@nightly
@require_torch_accelerator
@require_peft_backend
class LoraIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
class TestSDLoraIntegration:
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)
@@ -280,10 +252,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
prompt = "a red sks dog"
@@ -312,10 +281,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
prompt = "a red sks dog"
@@ -587,8 +553,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
assert not np.allclose(initial_images, lora_images)
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
release_memory(pipe)
@@ -625,8 +591,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
assert not np.allclose(initial_images, lora_images)
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
# make sure we can load a LoRA again after unloading and they don't have
# any undesired effects.
@@ -637,7 +603,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images
lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
assert np.allclose(lora_images, lora_images_again, atol=1e-3)
release_memory(pipe)
def test_not_empty_state_dict(self):
@@ -651,7 +617,7 @@ class LoraIntegrationTests(unittest.TestCase):
lcm_lora = load_file(cached_file)
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertTrue(lcm_lora != {})
assert lcm_lora != {}
release_memory(pipe)
def test_load_unload_load_state_dict(self):
@@ -666,11 +632,11 @@ class LoraIntegrationTests(unittest.TestCase):
previous_state_dict = lcm_lora.copy()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)
assert lcm_lora == previous_state_dict
pipe.unload_lora_weights()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)
assert lcm_lora == previous_state_dict
release_memory(pipe)
+10 -12
View File
@@ -14,9 +14,9 @@
# limitations under the License.
import gc
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -51,7 +51,7 @@ if is_accelerate_available():
@require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestSD3LoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -113,19 +113,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_modify_padding_mode(self):
pass
@@ -138,17 +138,15 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
class TestSD3LoraIntegration:
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
def setUp(self):
super().setUp()
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)
+19 -29
View File
@@ -17,9 +17,9 @@ import gc
import importlib
import sys
import time
import unittest
import numpy as np
import pytest
import torch
from packaging import version
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -59,7 +59,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory
class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
has_two_text_encoders = True
pipeline_class = StableDiffusionXLPipeline
scheduler_cls = EulerDiscreteScheduler
@@ -104,21 +104,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self):
return (1, 64, 64, 3)
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@is_flaky
def test_multiple_wrong_adapter_name_raises_error(self):
super().test_multiple_wrong_adapter_name_raises_error()
def test_simple_inference_with_text_denoiser_lora_unfused(self):
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -127,10 +117,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_rtol = 1e-3
super().test_simple_inference_with_text_denoiser_lora_unfused(
expected_atol=expected_atol, expected_rtol=expected_rtol
pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -139,10 +129,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_rtol = 1e-3
super().test_simple_inference_with_text_lora_denoiser_fused_multi(
expected_atol=expected_atol, expected_rtol=expected_rtol
pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
)
def test_lora_scale_kwargs_match_fusion(self):
def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -150,21 +140,21 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_atol = 1e-3
expected_rtol = 1e-3
super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
super().test_lora_scale_kwargs_match_fusion(
base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol
)
@slow
@nightly
@require_torch_accelerator
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
class TestLoraSDXLIntegration:
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)
@@ -383,7 +373,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
end_time = time.time()
elapsed_time_fusion = end_time - start_time
self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion)
assert elapsed_time_fusion < elapsed_time_non_fusion
release_memory(pipe)
@@ -439,14 +429,14 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
for key, value in text_encoder_1_sd.items():
key = remap_key(key, fused_te_state_dict)
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
assert torch.allclose(fused_te_state_dict[key], value)
for key, value in text_encoder_2_sd.items():
key = remap_key(key, fused_te_2_state_dict)
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
assert torch.allclose(fused_te_2_state_dict[key], value)
for key, value in unet_state_dict.items():
self.assertTrue(torch.allclose(unet_state_dict[key], value))
assert torch.allclose(unet_state_dict[key], value)
pipe.fuse_lora()
pipe.unload_lora_weights()
@@ -589,7 +579,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe = pipe.to(torch_device)
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet"
prompt = "toy_face of a hacker with a hoodie"
+14 -14
View File
@@ -13,8 +13,8 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -39,7 +39,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestWanLoRA(PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -104,40 +104,40 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+39 -52
View File
@@ -14,10 +14,9 @@
import os
import sys
import tempfile
import unittest
import numpy as np
import pytest
import safetensors.torch
import torch
from PIL import Image
@@ -32,7 +31,6 @@ from ..testing_utils import (
require_peft_backend,
require_peft_version_greater,
skip_mps,
torch_device,
)
@@ -47,7 +45,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestWanVACELoRA(PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -121,56 +119,51 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_save_load(self):
pass
def test_layerwise_casting_inference_denoiser(self):
super().test_layerwise_casting_inference_denoiser()
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self):
def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname, pipe):
exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
assert base_pipe_output.shape == self.output_shape
# only supported for `denoiser` now
denoiser_lora_config.target_modules = ["proj_out"]
@@ -180,36 +173,30 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
)
# The state dict shouldn't contain the modules to be excluded from LoRA.
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
assert not any(exclude_module_name in k for k in state_dict_from_model)
assert any("proj_out" in k for k in state_dict_from_model)
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
pipe.unload_lora_weights()
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
pipe.unload_lora_weights()
# Check in the loaded state dict.
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
# Check in the loaded state dict.
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
assert not any(exclude_module_name in k for k in loaded_state_dict)
assert any("proj_out" in k for k in loaded_state_dict)
# Check in the state dict obtained after loading LoRA.
pipe.load_lora_weights(tmpdir)
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
# Check in the state dict obtained after loading LoRA.
pipe.load_lora_weights(tmpdirname)
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
assert not any(exclude_module_name in k for k in state_dict_from_model)
assert any("proj_out" in k for k in state_dict_from_model)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
"LoRA should change outputs.",
)
self.assertTrue(
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
"Lora outputs should match.",
)
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
"LoRA should change outputs."
)
assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
"Lora outputs should match."
)
+587 -1022
View File
File diff suppressed because it is too large Load Diff
-1
View File
@@ -7,7 +7,6 @@ To run this test suite:
```bash
export RUN_ATTENTION_BACKEND_TESTS=yes
export DIFFUSERS_ENABLE_HUB_KERNELS=yes
pytest tests/others/test_attention_backends.py
```
+27 -29
View File
@@ -21,11 +21,9 @@ import numpy as np
import pytest
import torch
from transformers import (
ClapAudioConfig,
ClapConfig,
ClapFeatureExtractor,
ClapModel,
ClapTextConfig,
GPT2Config,
GPT2LMHeadModel,
RobertaTokenizer,
@@ -111,33 +109,33 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
latent_channels=4,
)
torch.manual_seed(0)
text_branch_config = ClapTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=8,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=1,
num_hidden_layers=1,
pad_token_id=1,
vocab_size=1000,
projection_dim=8,
)
audio_branch_config = ClapAudioConfig(
spec_size=8,
window_size=4,
num_mel_bins=8,
intermediate_size=37,
layer_norm_eps=1e-05,
depths=[1, 1],
num_attention_heads=[1, 1],
num_hidden_layers=1,
hidden_size=192,
projection_dim=8,
patch_size=2,
patch_stride=2,
patch_embed_input_channels=4,
)
text_branch_config = {
"bos_token_id": 0,
"eos_token_id": 2,
"hidden_size": 8,
"intermediate_size": 37,
"layer_norm_eps": 1e-05,
"num_attention_heads": 1,
"num_hidden_layers": 1,
"pad_token_id": 1,
"vocab_size": 1000,
"projection_dim": 8,
}
audio_branch_config = {
"spec_size": 8,
"window_size": 4,
"num_mel_bins": 8,
"intermediate_size": 37,
"layer_norm_eps": 1e-05,
"depths": [1, 1],
"num_attention_heads": [1, 1],
"num_hidden_layers": 1,
"hidden_size": 192,
"projection_dim": 8,
"patch_size": 2,
"patch_stride": 2,
"patch_embed_input_channels": 4,
}
text_encoder_config = ClapConfig(
text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
)
@@ -23,7 +23,7 @@ from diffusers import (
KandinskyV22InpaintCombinedPipeline,
)
from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ...testing_utils import enable_full_determinism, require_accelerator, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
from .test_kandinsky import Dummies
from .test_kandinsky_img2img import Dummies as Img2ImgDummies
@@ -402,6 +402,7 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4)
@require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
@@ -37,6 +37,7 @@ from ...testing_utils import (
load_image,
load_numpy,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch_accelerator,
slow,
torch_device,
@@ -254,6 +255,7 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4)
@require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
@@ -0,0 +1,238 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import (
AutoencoderKLWan,
FlowMatchEulerDiscreteScheduler,
SanaImageToVideoPipeline,
SanaVideoTransformer3DModel,
)
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class SanaImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = SanaImageToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_xformers_attention = False
supports_dduf = False
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
base_dim=3,
z_dim=16,
dim_mult=[1, 1, 1, 1],
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
text_encoder_config = Gemma2Config(
head_dim=16,
hidden_size=8,
initializer_range=0.02,
intermediate_size=64,
max_position_embeddings=8192,
model_type="gemma2",
num_attention_heads=2,
num_hidden_layers=1,
num_key_value_heads=2,
vocab_size=8,
attn_implementation="eager",
)
text_encoder = Gemma2Model(text_encoder_config)
tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
torch.manual_seed(0)
transformer = SanaVideoTransformer3DModel(
in_channels=16,
out_channels=16,
num_attention_heads=2,
attention_head_dim=12,
num_layers=2,
num_cross_attention_heads=2,
cross_attention_head_dim=12,
cross_attention_dim=24,
caption_channels=8,
mlp_ratio=2.5,
dropout=0.0,
attention_bias=False,
sample_size=8,
patch_size=(1, 2, 2),
norm_elementwise_affine=False,
norm_eps=1e-6,
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
)
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
# Create a dummy image input (PIL Image)
image = Image.new("RGB", (32, 32))
inputs = {
"image": image,
"prompt": "",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"height": 32,
"width": 32,
"frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
"complex_human_instruction": [],
"use_resolution_binning": False,
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
pass
def test_save_load_local(self, expected_max_difference=5e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
self.assertLess(max_diff, expected_max_difference)
# TODO(aryan): Create a dummy gemma model with smol vocab size
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_consistent(self):
pass
@unittest.skip(
"A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
)
def test_inference_batch_single_identical(self):
pass
@unittest.skip("Skipping fp16 test as model is trained with bf16")
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
@unittest.skip("Skipping fp16 test as model is trained with bf16")
def test_save_load_float16(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_save_load_float16(expected_max_diff=0.2)
@slow
@require_torch_accelerator
class SanaVideoPipelineIntegrationTests(unittest.TestCase):
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest."
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@unittest.skip("TODO: test needs to be implemented")
def test_sana_video_480p(self):
pass
@@ -37,6 +37,7 @@ from ...testing_utils import (
floats_tensor,
load_image,
load_numpy,
require_accelerator,
require_torch_accelerator,
slow,
torch_device,
@@ -222,6 +223,7 @@ class StableDiffusionLatentUpscalePipelineFastTests(
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=7e-3)
@require_accelerator
def test_sequential_cpu_offload_forward_pass(self):
super().test_sequential_cpu_offload_forward_pass(expected_max_diff=3e-3)