Compare commits
48 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b2e62d9487 | |||
| d5da453de5 | |||
| 15370f8412 | |||
| a96b145304 | |||
| 6d8973ffe2 | |||
| ab71f3c864 | |||
| b7df4a5387 | |||
| 67dc65e2e3 | |||
| d0b66ad469 | |||
| 3579fdabf9 | |||
| 9d1f757e32 | |||
| 1afc21855e | |||
| 0c35b580fe | |||
| 01a56927f1 | |||
| a9e4883b6a | |||
| 63dd601758 | |||
| 5307ae2d5d | |||
| 799cf8de89 | |||
| 2cf6dd1d88 | |||
| 9aea015e02 | |||
| eeae0338e7 | |||
| 3c1ca869d7 | |||
| 6fe4a6ff8e | |||
| 40de88af8c | |||
| 6a2309b98d | |||
| cd3bbe2910 | |||
| 7a001c3ee2 | |||
| d8e4805816 | |||
| 44c3101685 | |||
| d6c63bb956 | |||
| 2f44d63046 | |||
| f3db38c1e7 | |||
| f5e5f34823 | |||
| 093cd3f040 | |||
| aecf0c53bf | |||
| 0c7589293b | |||
| ff263947ad | |||
| 66e6a0215f | |||
| 73906381ab | |||
| 5a47442f92 | |||
| 8f6328c4a4 | |||
| 8d45f219d0 | |||
| 0fd58c7706 | |||
| 35d703310c | |||
| b455dc94a2 | |||
| 21a03f93ef | |||
| 04f9d2bf3d | |||
| bc8fd864eb |
@@ -73,6 +73,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -84,7 +86,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
@@ -126,6 +128,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
@@ -138,7 +142,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_${{ matrix.module }}_cuda \
|
||||
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
|
||||
tests/${{ matrix.module }}
|
||||
@@ -151,7 +155,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v --make-reports=examples_torch_cuda \
|
||||
--make-reports=examples_torch_cuda \
|
||||
--report-log=examples_torch_cuda.log \
|
||||
examples/
|
||||
|
||||
@@ -190,6 +194,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -198,7 +204,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
@@ -232,6 +238,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -281,6 +289,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -293,7 +303,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_minimum_version_cuda \
|
||||
tests/models/test_modeling_common.py \
|
||||
tests/pipelines/test_pipelines_common.py \
|
||||
@@ -358,6 +368,8 @@ jobs:
|
||||
uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
|
||||
fi
|
||||
uv pip install pytest-reportlog
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -405,6 +417,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install -U bitsandbytes optimum_quanto
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -531,7 +545,7 @@ jobs:
|
||||
# HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# run: |
|
||||
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
|
||||
# --report-log=tests_torch_mps.log \
|
||||
# tests/
|
||||
# - name: Failure short reports
|
||||
@@ -587,7 +601,7 @@ jobs:
|
||||
# HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# run: |
|
||||
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
|
||||
# --report-log=tests_torch_mps.log \
|
||||
# tests/
|
||||
# - name: Failure short reports
|
||||
|
||||
@@ -109,7 +109,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
@@ -120,7 +121,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
|
||||
@@ -115,7 +115,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
@@ -126,7 +127,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/pipelines
|
||||
|
||||
@@ -134,7 +135,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch_models' }}
|
||||
run: |
|
||||
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and not Dependency" \
|
||||
-k "not Flax and not Onnx and not Dependency" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
@@ -246,7 +247,8 @@ jobs:
|
||||
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
uv pip install -U tokenizers
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -255,11 +257,11 @@ jobs:
|
||||
- name: Run fast PyTorch LoRA tests with PEFT
|
||||
run: |
|
||||
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
\
|
||||
--make-reports=tests_peft_main \
|
||||
tests/lora/
|
||||
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
\
|
||||
--make-reports=tests_models_lora_peft_main \
|
||||
tests/models/ -k "lora"
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Fast GPU Tests on PR
|
||||
name: Fast GPU Tests on PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
@@ -71,7 +71,7 @@ jobs:
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
@@ -131,7 +131,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -149,18 +150,18 @@ jobs:
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
|
||||
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
else
|
||||
else
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and $pattern" \
|
||||
-k "not Flax and not Onnx and $pattern" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
@@ -201,7 +202,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -222,11 +224,11 @@ jobs:
|
||||
run: |
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
if [ -z "$pattern" ]; then
|
||||
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
else
|
||||
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
fi
|
||||
|
||||
- name: Failure short reports
|
||||
@@ -262,7 +264,8 @@ jobs:
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install -e ".[quality,training]"
|
||||
|
||||
- name: Environment
|
||||
@@ -274,7 +277,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
uv pip install ".[training]"
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -76,6 +76,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -86,7 +88,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
- name: Failure short reports
|
||||
@@ -127,6 +129,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -139,7 +143,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }} \
|
||||
tests/${{ matrix.module }}
|
||||
|
||||
@@ -178,6 +182,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -186,7 +192,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
@@ -227,7 +233,7 @@ jobs:
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
|
||||
@@ -270,7 +276,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
uv pip install ".[training]"
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
run: |
|
||||
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ jobs:
|
||||
HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
|
||||
${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -84,7 +84,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
- name: Failure short reports
|
||||
@@ -137,7 +137,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_${{ matrix.module }}_cuda \
|
||||
tests/${{ matrix.module }}
|
||||
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_minimum_cuda \
|
||||
tests/models/test_modeling_common.py \
|
||||
tests/pipelines/test_pipelines_common.py \
|
||||
@@ -240,7 +240,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
@@ -281,7 +281,7 @@ jobs:
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
|
||||
@@ -326,7 +326,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
uv pip install ".[training]"
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
title: Reproducibility
|
||||
- local: using-diffusers/schedulers
|
||||
title: Schedulers
|
||||
- local: using-diffusers/automodel
|
||||
title: AutoModel
|
||||
- local: using-diffusers/other-formats
|
||||
title: Model formats
|
||||
- local: using-diffusers/push_to_hub
|
||||
@@ -119,6 +121,8 @@
|
||||
title: ComponentsManager
|
||||
- local: modular_diffusers/guiders
|
||||
title: Guiders
|
||||
- local: modular_diffusers/custom_blocks
|
||||
title: Building Custom Blocks
|
||||
title: Modular Diffusers
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@@ -329,6 +333,8 @@
|
||||
title: BriaTransformer2DModel
|
||||
- local: api/models/chroma_transformer
|
||||
title: ChromaTransformer2DModel
|
||||
- local: api/models/chronoedit_transformer_3d
|
||||
title: ChronoEditTransformer3DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/cogview3plus_transformer2d
|
||||
@@ -385,6 +391,8 @@
|
||||
title: Transformer2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/wan_animate_transformer_3d
|
||||
title: WanAnimateTransformer3DModel
|
||||
- local: api/models/wan_transformer_3d
|
||||
title: WanTransformer3DModel
|
||||
title: Transformers
|
||||
@@ -446,6 +454,8 @@
|
||||
- sections:
|
||||
- local: api/pipelines/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- sections:
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
@@ -458,8 +468,6 @@
|
||||
- local: api/pipelines/stable_audio
|
||||
title: Stable Audio
|
||||
title: Audio
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- sections:
|
||||
- local: api/pipelines/amused
|
||||
title: aMUSEd
|
||||
@@ -523,6 +531,8 @@
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuanimage21
|
||||
title: HunyuanImage2.1
|
||||
- local: api/pipelines/pix2pix
|
||||
title: InstructPix2Pix
|
||||
- local: api/pipelines/kandinsky
|
||||
@@ -628,14 +638,14 @@
|
||||
- sections:
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
- local: api/pipelines/chronoedit
|
||||
title: ChronoEdit
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/hunyuanimage21
|
||||
title: HunyuanImage2.1
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
|
||||
@@ -12,15 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# AutoModel
|
||||
|
||||
The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel, AutoPipelineForText2Image
|
||||
|
||||
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
|
||||
```
|
||||
|
||||
[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file.
|
||||
|
||||
## AutoModel
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# ChronoEditTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data from [ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
|
||||
|
||||
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch
from diffusers import ChronoEditTransformer3DModel
|
||||
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained("nvidia/ChronoEdit-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## ChronoEditTransformer3DModel
|
||||
|
||||
[[autodoc]] ChronoEditTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# WanAnimateTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch
from diffusers import WanAnimateTransformer3DModel
|
||||
|
||||
transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## WanAnimateTransformer3DModel
|
||||
|
||||
[[autodoc]] WanAnimateTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,156 @@
|
||||
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# ChronoEdit
|
||||
|
||||
[ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
|
||||
|
||||
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
|
||||
|
||||
*Recent advances in large generative models have greatly enhanced both image editing and in-context image generation, yet a critical gap remains in ensuring physical consistency, where edited objects must remain coherent. This capability is especially vital for world simulation related tasks. In this paper, we present ChronoEdit, a framework that reframes image editing as a video generation problem. First, ChronoEdit treats the input and edited images as the first and last frames of a video, allowing it to leverage large pretrained video generative models that capture not only object appearance but also the implicit physics of motion and interaction through learned temporal consistency. Second, ChronoEdit introduces a temporal reasoning stage that explicitly performs editing at inference time. Under this setting, target frame is jointly denoised with reasoning tokens to imagine a plausible editing trajectory that constrains the solution space to physically viable transformations. The reasoning tokens are then dropped after a few steps to avoid the high computational cost of rendering a full video. To validate ChronoEdit, we introduce PBench-Edit, a new benchmark of image-prompt pairs for contexts that require physical consistency, and demonstrate that ChronoEdit surpasses state-of-the-art baselines in both visual fidelity and physical plausibility. Project page for code and models: [this https URL](https://research.nvidia.com/labs/toronto-ai/chronoedit).*
|
||||
|
||||
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
|
||||
|
||||
|
||||
### Image Editing
|
||||
|
||||
```py
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
from PIL import Image
|
||||
|
||||
model_id = "nvidia/ChronoEdit-14B-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
|
||||
)
|
||||
max_area = 720 * 1280
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
print("width", width, "height", height)
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
|
||||
"The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=5,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
enable_temporal_reasoning=False,
|
||||
num_temporal_reasoning_steps=0,
|
||||
).frames[0]
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
Optionally, enable **temporal reasoning** for improved physical consistency:
|
||||
```py
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=29,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
enable_temporal_reasoning=True,
|
||||
num_temporal_reasoning_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
### Inference with 8-Step Distillation LoRA
|
||||
|
||||
```py
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline, UniPCMultistepScheduler
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import CLIPVisionModel
|
||||
from PIL import Image
|
||||
|
||||
model_id = "nvidia/ChronoEdit-14B-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
|
||||
pipe.load_lora_weights(lora_path)
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
|
||||
)
|
||||
max_area = 720 * 1280
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
print("width", width, "height", height)
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
|
||||
"The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=5,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=1.0,
|
||||
enable_temporal_reasoning=False,
|
||||
num_temporal_reasoning_steps=0,
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
## ChronoEditPipeline
|
||||
|
||||
[[autodoc]] ChronoEditPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ChronoEditPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.chronoedit.pipeline_output.ChronoEditPipelineOutput
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# SanaVideoPipeline
|
||||
# Sana-Video
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
@@ -37,6 +37,85 @@ Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-vi
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
|
||||
## Generation Pipelines
|
||||
|
||||
<hfoptions id="generation pipelines">
|
||||
<hfoption id="Text-to-Video">
|
||||
|
||||
The example below demonstrates how to use the text-to-video pipeline to generate a video from a text description.
|
||||
|
||||
```python
|
||||
model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
|
||||
pipe = SanaVideoPipeline.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe.text_encoder.to(torch.bfloat16)
|
||||
pipe.vae.to(torch.float32)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
||||
motion_scale = 30
|
||||
motion_prompt = f" motion score: {motion_scale}."
|
||||
prompt = prompt + motion_prompt
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=832,
|
||||
frames=81,
|
||||
guidance_scale=6,
|
||||
num_inference_steps=50,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
).frames[0]
|
||||
|
||||
export_to_video(video, "sana_video.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Image-to-Video">
|
||||
|
||||
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.
|
||||
|
||||
```python
|
||||
model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
|
||||
pipe = SanaImageToVideoPipeline.from_pretrained(
|
||||
model_id,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
|
||||
pipe.vae.to(torch.float32)
|
||||
pipe.text_encoder.to(torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
|
||||
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
|
||||
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
||||
motion_scale = 30
|
||||
motion_prompt = f" motion score: {motion_scale}."
|
||||
prompt = prompt + motion_prompt
|
||||
|
||||
motion_scale = 30.0
|
||||
|
||||
video = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=832,
|
||||
frames=81,
|
||||
guidance_scale=6,
|
||||
num_inference_steps=50,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
).frames[0]
|
||||
|
||||
export_to_video(video, "sana-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
@@ -97,6 +176,13 @@ export_to_video(output, "sana-video-output.mp4", fps=16)
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaImageToVideoPipeline
|
||||
|
||||
[[autodoc]] SanaImageToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaVideoPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana.pipeline_sana_video.SanaVideoPipelineOutput
|
||||
[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput
|
||||
|
||||
@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
|
||||
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
|
||||
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
|
||||
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
|
||||
- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Wan models in the right sidebar for more examples of video generation.
|
||||
@@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained(
|
||||
pipeline.to("cuda")
|
||||
|
||||
prompt = """
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
@@ -150,15 +151,15 @@ pipeline.transformer = torch.compile(
|
||||
)
|
||||
|
||||
prompt = """
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
@@ -249,6 +250,208 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
|
||||
|
||||
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
|
||||
|
||||
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
|
||||
|
||||
*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
|
||||
|
||||
The project page: https://humanaigc.github.io/wan-animate
|
||||
|
||||
This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
||||
|
||||
#### Usage
|
||||
|
||||
The Wan-Animate pipeline supports two modes of operation:
|
||||
|
||||
1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
|
||||
2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
|
||||
|
||||
##### Prerequisites
|
||||
|
||||
Before using the pipeline, you need to preprocess your reference video to extract:
|
||||
- **Pose video**: Contains skeletal keypoints representing body motion
|
||||
- **Face video**: Contains facial feature representations for expression control
|
||||
|
||||
For replacement mode, you additionally need:
|
||||
- **Background video**: The original video containing the scene
|
||||
- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
|
||||
|
||||
> [!NOTE]
|
||||
> Raw videos should not be used for inputs such as `pose_video`, which the pipeline expects to be preprocessed to extract the proper information. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release.
|
||||
|
||||
The example below demonstrates how to use the Wan-Animate pipeline:
|
||||
|
||||
<hfoptions id="Animate usage">
|
||||
<hfoption id="Animation mode">
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, WanAnimatePipeline
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Load character image and preprocessed videos
|
||||
image = load_image("path/to/character.jpg")
|
||||
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
|
||||
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
|
||||
|
||||
# Resize image to match VAE constraints
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
image, height, width = aspect_ratio_resize(image, pipe)
|
||||
|
||||
prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
|
||||
negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
|
||||
|
||||
# Generate animated video
|
||||
output = pipe(
|
||||
image=image,
|
||||
pose_video=pose_video,
|
||||
face_video=face_video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
segment_frame_length=77,
|
||||
guidance_scale=1.0,
|
||||
mode="animate", # Animation mode (default)
|
||||
).frames[0]
|
||||
export_to_video(output, "animated_character.mp4", fps=30)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Replacement mode">
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, WanAnimatePipeline
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Load all required inputs for replacement mode
|
||||
image = load_image("path/to/new_character.jpg")
|
||||
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
|
||||
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
|
||||
background_video = load_video("path/to/background_video.mp4") # Original scene
|
||||
mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
|
||||
|
||||
# Resize image to match VAE constraints
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
image, height, width = aspect_ratio_resize(image, pipe)
|
||||
|
||||
prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
|
||||
negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
|
||||
|
||||
# Replace character in background video
|
||||
output = pipe(
|
||||
image=image,
|
||||
pose_video=pose_video,
|
||||
face_video=face_video,
|
||||
background_video=background_video,
|
||||
mask_video=mask_video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
segment_frame_length=77,
|
||||
guidance_scale=1.0,
|
||||
mode="replace", # Replacement mode
|
||||
).frames[0]
|
||||
export_to_video(output, "character_replaced.mp4", fps=30)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Advanced options">
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, WanAnimatePipeline
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image("path/to/character.jpg")
|
||||
pose_video = load_video("path/to/pose_video.mp4")
|
||||
face_video = load_video("path/to/face_video.mp4")
|
||||
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
image, height, width = aspect_ratio_resize(image, pipe)
|
||||
|
||||
prompt = "A person dancing energetically in a studio"
|
||||
negative_prompt = "blurry, low quality"
|
||||
|
||||
# Advanced: Use temporal guidance and custom callback
|
||||
def callback_fn(pipe, step_index, timestep, callback_kwargs):
|
||||
# You can modify latents or other tensors here
|
||||
print(f"Step {step_index}, Timestep {timestep}")
|
||||
return callback_kwargs
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
pose_video=pose_video,
|
||||
face_video=face_video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
segment_frame_length=77,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
prev_segment_conditioning_frames=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
|
||||
callback_on_step_end=callback_fn,
|
||||
callback_on_step_end_tensor_inputs=["latents"],
|
||||
).frames[0]
|
||||
export_to_video(output, "animated_advanced.mp4", fps=30)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
#### Key Parameters
|
||||
|
||||
- **mode**: Choose between `"animate"` (default) or `"replace"`
|
||||
- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
|
||||
- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions. (Note that CFG will only target the text prompt and face conditioning.)
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
|
||||
@@ -281,10 +484,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
|
||||
# use "steamboat willie style" to trigger the LoRA
|
||||
prompt = """
|
||||
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
|
||||
"""
|
||||
|
||||
@@ -359,6 +562,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanAnimatePipeline
|
||||
|
||||
[[autodoc]] WanAnimatePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
|
||||
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
|
||||
|
||||
@@ -0,0 +1,492 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
|
||||
# Building Custom Blocks
|
||||
|
||||
[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
|
||||
|
||||
> [!TIP]
|
||||
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
|
||||
|
||||
## Project Structure
|
||||
|
||||
Your custom block project should use the following structure:
|
||||
|
||||
```shell
|
||||
.
|
||||
├── block.py
|
||||
└── modular_config.json
|
||||
```
|
||||
|
||||
- `block.py` contains the custom block implementation
|
||||
- `modular_config.json` contains the metadata needed to load the block
|
||||
|
||||
## Example: Florence 2 Inpainting Block
|
||||
|
||||
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
|
||||
|
||||
The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, its model class via `type_hint`, and a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
|
||||
|
||||
```py
|
||||
# Inside block.py
|
||||
from diffusers.modular_pipelines import (
|
||||
ModularPipelineBlocks,
|
||||
ComponentSpec,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
|
||||
|
||||
```py
|
||||
from typing import List, Union
|
||||
from PIL import Image, ImageDraw
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
PipelineState,
|
||||
ModularPipelineBlocks,
|
||||
InputParam,
|
||||
ComponentSpec,
|
||||
OutputParam,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"image",
|
||||
type_hint=Union[Image.Image, List[Image.Image]],
|
||||
required=True,
|
||||
description="Image(s) to annotate",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_task",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
default="<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
description="""Annotation Task to perform on the image.
|
||||
Supported Tasks:
|
||||
|
||||
<OD>
|
||||
<REFERRING_EXPRESSION_SEGMENTATION>
|
||||
<CAPTION>
|
||||
<DETAILED_CAPTION>
|
||||
<MORE_DETAILED_CAPTION>
|
||||
<DENSE_REGION_CAPTION>
|
||||
<CAPTION_TO_PHRASE_GROUNDING>
|
||||
<OPEN_VOCABULARY_DETECTION>
|
||||
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
description="""Annotation Prompt to provide more context to the task.
|
||||
Can be used to detect or segment out specific elements in the image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_output_type",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Available options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
- mask overlayed on the original image
|
||||
bounding_box:
|
||||
- bounding boxes drawn on the original image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_overlay",
|
||||
type_hint=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
description="",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"mask_image",
|
||||
type_hint=Image,
|
||||
description="Inpainting Mask for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"annotations",
|
||||
type_hint=dict,
|
||||
description="Annotations Predictions for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"image",
|
||||
type_hint=Image,
|
||||
description="Annotated input Image(s)",
|
||||
),
|
||||
]
|
||||
|
||||
```
|
||||
|
||||
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
|
||||
|
||||
```py
|
||||
from typing import List, Union
|
||||
from PIL import Image, ImageDraw
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
PipelineState,
|
||||
ModularPipelineBlocks,
|
||||
InputParam,
|
||||
ComponentSpec,
|
||||
OutputParam,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"image",
|
||||
type_hint=Union[Image.Image, List[Image.Image]],
|
||||
required=True,
|
||||
description="Image(s) to annotate",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_task",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
default="<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
description="""Annotation Task to perform on the image.
|
||||
Supported Tasks:
|
||||
|
||||
<OD>
|
||||
<REFERRING_EXPRESSION_SEGMENTATION>
|
||||
<CAPTION>
|
||||
<DETAILED_CAPTION>
|
||||
<MORE_DETAILED_CAPTION>
|
||||
<DENSE_REGION_CAPTION>
|
||||
<CAPTION_TO_PHRASE_GROUNDING>
|
||||
<OPEN_VOCABULARY_DETECTION>
|
||||
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
description="""Annotation Prompt to provide more context to the task.
|
||||
Can be used to detect or segment out specific elements in the image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_output_type",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Available options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
- mask overlayed on the original image
|
||||
bounding_box:
|
||||
- bounding boxes drawn on the original image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_overlay",
|
||||
type_hint=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
description="",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"mask_image",
|
||||
type_hint=Image,
|
||||
description="Inpainting Mask for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"annotations",
|
||||
type_hint=dict,
|
||||
description="Annotations Predictions for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"image",
|
||||
type_hint=Image,
|
||||
description="Annotated input Image(s)",
|
||||
),
|
||||
]
|
||||
|
||||
def get_annotations(self, components, images, prompts, task):
|
||||
task_prompts = [task + prompt for prompt in prompts]
|
||||
|
||||
inputs = components.image_annotator_processor(
|
||||
text=task_prompts, images=images, return_tensors="pt"
|
||||
).to(components.image_annotator.device, components.image_annotator.dtype)
|
||||
|
||||
generated_ids = components.image_annotator.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=1024,
|
||||
early_stopping=False,
|
||||
do_sample=False,
|
||||
num_beams=3,
|
||||
)
|
||||
annotations = components.image_annotator_processor.batch_decode(
|
||||
generated_ids, skip_special_tokens=False
|
||||
)
|
||||
outputs = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
outputs.append(
|
||||
components.image_annotator_processor.post_process_generation(
|
||||
annotation, task=task, image_size=(image.width, image.height)
|
||||
)
|
||||
)
|
||||
return outputs
|
||||
|
||||
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
|
||||
masks = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
|
||||
draw = ImageDraw.Draw(mask_image)
|
||||
|
||||
for _, _annotation in annotation.items():
|
||||
if "polygons" in _annotation:
|
||||
for polygon in _annotation["polygons"]:
|
||||
polygon = np.array(polygon).reshape(-1, 2)
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
polygon = polygon.reshape(-1).tolist()
|
||||
draw.polygon(polygon, fill=fill)
|
||||
|
||||
elif "bbox" in _annotation:
|
||||
bbox = _annotation["bbox"]
|
||||
draw.rectangle(bbox, fill="white")
|
||||
|
||||
masks.append(mask_image)
|
||||
|
||||
return masks
|
||||
|
||||
def prepare_bounding_boxes(self, images, annotations):
|
||||
outputs = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
image_copy = image.copy()
|
||||
draw = ImageDraw.Draw(image_copy)
|
||||
for _, _annotation in annotation.items():
|
||||
bbox = _annotation["bbox"]
|
||||
label = _annotation["label"]
|
||||
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
|
||||
|
||||
outputs.append(image_copy)
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs(self, images, prompts):
|
||||
prompts = prompts or ""
|
||||
|
||||
if isinstance(images, Image.Image):
|
||||
images = [images]
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if len(images) != len(prompts):
|
||||
raise ValueError("Number of images and annotation prompts must match.")
|
||||
|
||||
return images, prompts
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
images, annotation_task_prompt = self.prepare_inputs(
|
||||
block_state.image, block_state.annotation_prompt
|
||||
)
|
||||
task = block_state.annotation_task
|
||||
fill = block_state.fill
|
||||
|
||||
annotations = self.get_annotations(
|
||||
components, images, annotation_task_prompt, task
|
||||
)
|
||||
block_state.annotations = annotations
|
||||
if block_state.annotation_output_type == "mask_image":
|
||||
block_state.mask_image = self.prepare_mask(images, annotations)
|
||||
else:
|
||||
block_state.mask_image = None
|
||||
|
||||
if block_state.annotation_output_type == "mask_overlay":
|
||||
block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)
|
||||
|
||||
elif block_state.annotation_output_type == "bounding_box":
|
||||
block_state.image = self.prepare_bounding_boxes(images, annotations)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
```
|
||||
|
||||
Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
|
||||
|
||||
<hfoptions id="share">
|
||||
<hfoption id="hf CLI">
|
||||
|
||||
```shell
|
||||
# In the folder with the `block.py` file, run:
|
||||
diffusers-cli custom_block
|
||||
```
|
||||
|
||||
Then upload the block to the Hub:
|
||||
|
||||
```shell
|
||||
hf upload <your repo id> . .
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="push_to_hub">
|
||||
|
||||
```py
|
||||
from block import Florence2ImageAnnotatorBlock
|
||||
block = Florence2ImageAnnotatorBlock()
|
||||
block.push_to_hub("<your repo id>")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Using Custom Blocks
|
||||
|
||||
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# Fetch the Florence2 image annotator block that will create our mask
|
||||
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
|
||||
|
||||
my_blocks = INPAINT_BLOCKS.copy()
|
||||
# insert the annotation block before the image encoding step
|
||||
my_blocks.insert("image_annotator", image_annotator_block, 1)
|
||||
|
||||
# Create our initial set of inpainting blocks
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
|
||||
|
||||
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
|
||||
pipe = blocks.init_pipeline(repo_id)
|
||||
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
|
||||
image = image.resize((1024, 1024))
|
||||
|
||||
prompt = ["A red car"]
|
||||
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
|
||||
annotation_prompt = ["the car"]
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
annotation_task=annotation_task,
|
||||
annotation_prompt=annotation_prompt,
|
||||
annotation_output_type="mask_image",
|
||||
num_inference_steps=35,
|
||||
guidance_scale=7.5,
|
||||
strength=0.95,
|
||||
output="images"
|
||||
)
|
||||
output[0].save("florence-inpainting.png")
|
||||
```
|
||||
|
||||
## Editing Custom Blocks
|
||||
|
||||
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# Fetch the Florence2 image annotator block that will create our mask
|
||||
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
|
||||
```
|
||||
|
||||
Any changes made to the block files in this folder will be reflected when you load the block again.
|
||||
@@ -0,0 +1,46 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoModel
|
||||
|
||||
The [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports data types and device placement, and works across model types and libraries.
|
||||
|
||||
The example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel, DiffusionPipeline
|
||||
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
|
||||
text_encoder = AutoModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel
|
||||
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
|
||||
|
||||
> [!NOTE]
|
||||
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
|
||||
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
|
||||
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
|
||||
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
|
||||
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
|
||||
|
||||
| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports ControlNet with the Flux Fill model. | [Flux Fill ControlNet Pipeline](#flux-fill-controlnet-pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) |
|
||||
|
||||
To load a custom pipeline, pass the `custom_pipeline` argument to `DiffusionPipeline`, set to the name of one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
|
||||
|
||||
@@ -5488,7 +5488,7 @@ Editing at Scale", many thanks to their contribution!
|
||||
|
||||
This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
|
||||
|
||||
As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
|
||||
As explained in Section 3 of [the paper](https://huggingface.co/papers/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
|
||||
|
||||
## Example Usage
|
||||
|
||||
@@ -5527,3 +5527,106 @@ images = pipe(
|
||||
).images
|
||||
images[0].save("pizzeria.png")
|
||||
```
|
||||
|
||||
# Flux Fill ControlNet Pipeline
|
||||
|
||||
This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.
|
||||
|
||||
While FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.
|
||||
|
||||
## Example Usage
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import (
|
||||
FluxControlNetModel,
|
||||
FluxPriorReduxPipeline,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# NEW PIPELINE (updated name)
|
||||
from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Models
|
||||
base_model = "black-forest-labs/FLUX.1-Fill-dev"
|
||||
controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0"
|
||||
prior_model = "black-forest-labs/FLUX.1-Redux-dev"
|
||||
|
||||
# Load ControlNet
|
||||
controlnet = FluxControlNetModel.from_pretrained(
|
||||
controlnet_model,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
|
||||
# Load Fill + ControlNet Pipeline
|
||||
fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
|
||||
base_model,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
# OPTIONAL FP8
|
||||
# fill_pipe.transformer.enable_layerwise_casting(
|
||||
# storage_dtype=torch.float8_e4m3fn,
|
||||
# compute_dtype=torch.bfloat16
|
||||
# )
|
||||
|
||||
# OPTIONAL Prior Redux
|
||||
#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
|
||||
# prior_model,
|
||||
# torch_dtype=dtype,
|
||||
#).to(device)
|
||||
|
||||
# Inputs
|
||||
|
||||
# combined_image = load_image("person_input.png")
|
||||
|
||||
|
||||
# 1. Prior conditioning
|
||||
#prior_out = pipe_prior_redux(
|
||||
# image=cloth_image,
|
||||
# prompt=cloth_prompt,
|
||||
#)
|
||||
|
||||
# 2. Fill Inpaint with ControlNet
|
||||
|
||||
# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
|
||||
|
||||
img = load_image(r"imgs/background.jpg")
|
||||
mask = load_image(r"imgs/mask.png")
|
||||
|
||||
control_image_depth = load_image(r"imgs/dog_depth _2.png")
|
||||
|
||||
result = fill_pipe(
|
||||
prompt="a dog on a bench",
|
||||
image=img,
|
||||
mask_image=mask,
|
||||
|
||||
control_image=control_image_depth,
|
||||
control_mode=[2], # union mode
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=0.8,
|
||||
controlnet_conditioning_scale=0.9,
|
||||
|
||||
height=1024,
|
||||
width=1024,
|
||||
|
||||
strength=1.0,
|
||||
guidance_scale=50.0,
|
||||
num_inference_steps=60,
|
||||
max_sequence_length=512,
|
||||
|
||||
# **prior_out,
|
||||
)
|
||||
|
||||
# result.images[0].save("flux_fill_controlnet_inpaint.png")
|
||||
|
||||
from datetime import datetime
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
result.images[0].save(f"flux_fill_controlnet_inpaint_depth{timestamp}.jpg")
|
||||
```
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def check_size(image, height, width):
|
||||
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
|
||||
|
||||
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
|
||||
inner_image = inner_image.convert("RGBA")
|
||||
image = image.convert("RGB")
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ from diffusers.loaders import (
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
@@ -1328,18 +1327,8 @@ class SDXLLongPromptWeightingPipeline(
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
|
||||
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
|
||||
|
||||
def _check_config(
|
||||
self,
|
||||
down_block_types: Tuple[str],
|
||||
up_block_types: Tuple[str],
|
||||
down_block_types: Tuple[str, ...],
|
||||
up_block_types: Tuple[str, ...],
|
||||
only_cross_attention: Union[bool, Tuple[bool]],
|
||||
block_out_channels: Tuple[int],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
layers_per_block: Union[int, Tuple[int]],
|
||||
cross_attention_dim: Union[int, Tuple[int]],
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||
|
||||
@@ -30,17 +30,13 @@ from diffusers.loaders import (
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_invisible_watermark_available,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
@@ -710,22 +706,8 @@ class StableDiffusionXLTilingPipeline(
|
||||
return torch.tile(weights_torch, (nbatches, self.unet.config.in_channels, 1, 1))
|
||||
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(
|
||||
|
||||
@@ -39,16 +39,13 @@ from diffusers.models import (
|
||||
MultiControlNetModel,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers, LMSDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
@@ -1220,23 +1217,9 @@ class StableDiffusionXLControlNetTileSRPipeline(
|
||||
|
||||
return tile_weights, tile_row_overlaps, tile_col_overlaps
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
|
||||
@@ -40,10 +40,6 @@ from diffusers.models import (
|
||||
MultiControlNetModel,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
@@ -760,21 +756,8 @@ class KolorsControlNetPipeline(
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
|
||||
@@ -40,10 +40,6 @@ from diffusers.models import (
|
||||
MultiControlNetModel,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
@@ -930,21 +926,8 @@ class KolorsControlNetImg2ImgPipeline(
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
|
||||
@@ -39,10 +39,6 @@ from diffusers.models import (
|
||||
MultiControlNetModel,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
@@ -1006,21 +1002,8 @@ class KolorsControlNetInpaintPipeline(
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@property
|
||||
def denoising_end(self):
|
||||
|
||||
@@ -16,11 +16,11 @@ from diffusers.loaders import (
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
is_invisible_watermark_available,
|
||||
@@ -612,20 +612,9 @@ class DemoFusionSDXLPipeline(
|
||||
|
||||
return image
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
|
||||
@@ -40,13 +40,6 @@ from diffusers.loaders import (
|
||||
UNet2DConditionLoadersMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
@@ -438,16 +431,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -1637,24 +1635,8 @@ class FaithDiffStableDiffusionXLPipeline(
|
||||
return latents
|
||||
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(
|
||||
|
||||
@@ -22,13 +22,12 @@ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
|
||||
from diffusers.pipelines.kolors.pipeline_output import KolorsPipelineOutput
|
||||
from diffusers.pipelines.kolors.text_encoder import ChatGLMModel
|
||||
from diffusers.pipelines.kolors.tokenizer import ChatGLMTokenizer
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
@@ -709,24 +708,9 @@ class KolorsDifferentialImg2ImgPipeline(
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(
|
||||
|
||||
@@ -32,12 +32,6 @@ from diffusers.loaders import (
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.pipelines.kolors import ChatGLMModel, ChatGLMTokenizer
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
@@ -1008,23 +1002,8 @@ class KolorsInpaintPipeline(
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(
|
||||
|
||||
@@ -45,8 +45,6 @@ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionMode
|
||||
from diffusers.models.attention_processor import (
|
||||
Attention,
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
@@ -1151,22 +1149,8 @@ class StyleAlignedSDXLPipeline(
|
||||
return add_time_ids
|
||||
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
FusedAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
def _enable_shared_attention_processors(
|
||||
self,
|
||||
|
||||
@@ -503,24 +503,9 @@ class StableDiffusionUpscaleLDM3DPipeline(
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
# def upcast_vae(self):
|
||||
# dtype = self.vae.dtype
|
||||
# self.vae.to(dtype=torch.float32)
|
||||
# use_torch_2_0_or_xformers = isinstance(
|
||||
# self.vae.decoder.mid_block.attentions[0].processor,
|
||||
# (
|
||||
# AttnProcessor2_0,
|
||||
# XFormersAttnProcessor,
|
||||
# LoRAXFormersAttnProcessor,
|
||||
# LoRAAttnProcessor2_0,
|
||||
# ),
|
||||
# )
|
||||
# # if xformers or torch_2_0 is used attention block does not need
|
||||
# # to be in float32 which can save lots of memory
|
||||
# if use_torch_2_0_or_xformers:
|
||||
# self.vae.post_quant_conv.to(dtype)
|
||||
# self.vae.decoder.conv_in.to(dtype)
|
||||
# self.vae.decoder.mid_block.to(dtype)
|
||||
def upcast_vae(self):
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
|
||||
@@ -35,12 +35,6 @@ from diffusers.loaders import (
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
@@ -1282,23 +1276,8 @@ class StableDiffusionXL_AE_Pipeline(
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
|
||||
@@ -25,7 +25,6 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
@@ -34,6 +33,7 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
@@ -793,20 +793,9 @@ class StableDiffusionXLControlNetAdapterPipeline(
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
|
||||
def _default_height_width(self, height, width, image):
|
||||
|
||||
@@ -43,7 +43,6 @@ from diffusers.models import (
|
||||
T2IAdapter,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
|
||||
@@ -52,6 +51,7 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
PIL_INTERPOLATION,
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
@@ -1130,20 +1130,9 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
|
||||
return add_time_ids, add_neg_time_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(AttnProcessor2_0, XFormersAttnProcessor),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
|
||||
def _default_height_width(self, height, width, image):
|
||||
|
||||
@@ -35,10 +35,6 @@ from diffusers.loaders import (
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
@@ -848,21 +844,8 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(
|
||||
|
||||
@@ -32,10 +32,6 @@ from diffusers.loaders import (
|
||||
TextualInversionLoaderMixin,
|
||||
)
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
@@ -658,23 +654,9 @@ class StableDiffusionXLPipelineIpex(
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
|
||||
def upcast_vae(self):
|
||||
dtype = self.vae.dtype
|
||||
deprecate("upcast_vae", "1.0.0", "`upcast_vae` is deprecated. Please use `pipe.vae.to(torch.float32)`")
|
||||
self.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
self.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
self.vae.post_quant_conv.to(dtype)
|
||||
self.vae.decoder.conv_in.to(dtype)
|
||||
self.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
||||
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -490,7 +490,7 @@ class RegionalPromptingStableDiffusionPipeline(
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
@@ -841,7 +841,7 @@ class RegionalPromptingStableDiffusionPipeline(
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
@@ -872,7 +872,7 @@ class RegionalPromptingStableDiffusionPipeline(
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
@@ -1062,7 +1062,7 @@ class RegionalPromptingStableDiffusionPipeline(
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
@@ -1668,7 +1668,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
|
||||
@@ -268,12 +268,11 @@ provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_f
|
||||
**important**
|
||||
|
||||
> [!NOTE]
|
||||
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
|
||||
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source.
|
||||
> To do this, execute the following steps in a new virtual environment:
|
||||
> ```
|
||||
> git clone https://github.com/huggingface/diffusers
|
||||
> cd diffusers
|
||||
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
|
||||
> pip install -e .
|
||||
> ```
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from accelerate import init_empty_weights
|
||||
from diffusers import (
|
||||
SanaControlNetModel,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
SanaTransformer2DModel,
|
||||
SCMScheduler,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -80,6 +80,8 @@ def main(args):
|
||||
|
||||
# scheduler
|
||||
flow_shift = 8.0
|
||||
if args.task == "i2v":
|
||||
assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
|
||||
|
||||
# model config
|
||||
layer_num = 20
|
||||
@@ -312,6 +314,7 @@ if __name__ == "__main__":
|
||||
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
|
||||
help="Scheduler type to use.",
|
||||
)
|
||||
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
|
||||
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
@@ -7,7 +7,7 @@ from accelerate import init_empty_weights
|
||||
|
||||
from diffusers import AutoencoderKL, SD3Transformer2DModel
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers import (
|
||||
StableAudioPipeline,
|
||||
StableAudioProjectionModel,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -6,11 +6,20 @@ import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModel,
|
||||
CLIPVisionModelWithProjection,
|
||||
UMT5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
UniPCMultistepScheduler,
|
||||
WanAnimatePipeline,
|
||||
WanAnimateTransformer3DModel,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanTransformer3DModel,
|
||||
@@ -105,8 +114,203 @@ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"after_proj": "proj_out",
|
||||
}
|
||||
|
||||
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
||||
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
||||
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
||||
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
||||
"time_projection.1": "condition_embedder.time_proj",
|
||||
"head.modulation": "scale_shift_table",
|
||||
"head.head": "proj_out",
|
||||
"modulation": "scale_shift_table",
|
||||
"ffn.0": "ffn.net.0.proj",
|
||||
"ffn.2": "ffn.net.2",
|
||||
# Hack to swap the layer names
|
||||
# The original model calls the norms in following order: norm1, norm3, norm2
|
||||
# We convert it to: norm1, norm2, norm3
|
||||
"norm2": "norm__placeholder",
|
||||
"norm3": "norm2",
|
||||
"norm__placeholder": "norm3",
|
||||
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
||||
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
# Add attention component mappings
|
||||
"self_attn.q": "attn1.to_q",
|
||||
"self_attn.k": "attn1.to_k",
|
||||
"self_attn.v": "attn1.to_v",
|
||||
"self_attn.o": "attn1.to_out.0",
|
||||
"self_attn.norm_q": "attn1.norm_q",
|
||||
"self_attn.norm_k": "attn1.norm_k",
|
||||
"cross_attn.q": "attn2.to_q",
|
||||
"cross_attn.k": "attn2.to_k",
|
||||
"cross_attn.v": "attn2.to_v",
|
||||
"cross_attn.o": "attn2.to_out.0",
|
||||
"cross_attn.norm_q": "attn2.norm_q",
|
||||
"cross_attn.norm_k": "attn2.norm_k",
|
||||
"cross_attn.k_img": "attn2.to_k_img",
|
||||
"cross_attn.v_img": "attn2.to_v_img",
|
||||
"cross_attn.norm_k_img": "attn2.norm_k_img",
|
||||
# After cross_attn -> attn2 rename, we need to rename the img keys
|
||||
"attn2.to_k_img": "attn2.add_k_proj",
|
||||
"attn2.to_v_img": "attn2.add_v_proj",
|
||||
"attn2.norm_k_img": "attn2.norm_added_k",
|
||||
# Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
|
||||
# Motion encoder mappings
|
||||
# The name mapping is complicated for the convolutional part so we handle that in its own function
|
||||
"motion_encoder.enc.fc": "motion_encoder.motion_network",
|
||||
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
|
||||
# Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
|
||||
"face_encoder.conv1_local.conv": "face_encoder.conv1_local",
|
||||
"face_encoder.conv2.conv": "face_encoder.conv2",
|
||||
"face_encoder.conv3.conv": "face_encoder.conv3",
|
||||
# Face adapter mappings are handled in a separate function
|
||||
}
|
||||
|
||||
|
||||
# TODO: Verify this and simplify if possible.
|
||||
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
    """
    Convert a single motion encoder entry of an Animate checkpoint, modifying `state_dict` in place.

    In the original model:
    - All Linear layers in fc use EqualLinear
    - All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
    - Blur kernels are stored as buffers in Sequential modules
    - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]

    Conversion strategy:
    1. Drop `.kernel` buffers (blur kernels).
    2. Rename `enc.net_app.convs.<idx>...` entries to `conv_in`, `res_blocks.<idx - 1>.<layer>` or `conv_out`;
       biases are stored under `act_fn.bias` and squeezed.

    Entries for which no target name can be derived (for example a bias on the final conv layer) are left untouched
    instead of failing on an unbound variable, so they can be reported by the subsequent strict state dict load.

    Args:
        key (`str`):
            Name of the state dict entry to convert.
        state_dict (`Dict[str, Any]`):
            The checkpoint mapping. The entry is removed and/or re-inserted under its new name.
        final_conv_idx (`int`, defaults to 8):
            Index under `enc.net_app.convs` that is mapped to `conv_out`.
    """
    # Skip if not a weight, bias, or kernel
    if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
        return

    # Handle Blur kernel buffers from original implementation.
    # Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
    # to avoid strict load errors.
    if ".kernel" in key and "motion_encoder" in key:
        state_dict.pop(key, None)
        return

    # Only the convolutional part needs index-based renaming
    if ".enc.net_app.convs." not in key or (".weight" not in key and ".bias" not in key):
        return

    parts = key.split(".")
    if "convs" not in parts:
        return
    convs_idx = parts.index("convs")
    if len(parts) - convs_idx < 2:
        return

    # The nn.Sequential index always follows `convs`; anything else is not a key we know how to map
    idx_token = parts[convs_idx + 1]
    if not idx_token.isdigit():
        return
    sequential_idx = int(idx_token)

    is_weight = key.endswith(".weight")
    is_bias = key.endswith(".bias")

    # Examples:
    #   - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
    #   - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
    #   - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
    #   - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
    #   - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
    #   - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
    #   - enc.net_app.convs.8 -> conv_out (final conv layer)
    new_key = None
    if sequential_idx == 0:
        if is_weight:
            new_key = "motion_encoder.conv_in.weight"
        elif is_bias:
            new_key = "motion_encoder.conv_in.act_fn.bias"
    elif sequential_idx == final_conv_idx:
        # Only a weight is mapped for the final conv layer
        if is_weight:
            new_key = "motion_encoder.conv_out.weight"
    elif is_weight or is_bias:
        # Intermediate .convs. layers, which get mapped to .res_blocks.
        layer_name = parts[convs_idx + 2]
        if layer_name == "skip":
            layer_name = "conv_skip"
        param_name = "weight" if is_weight else "act_fn.bias"
        new_key = ".".join(["motion_encoder.res_blocks", str(sequential_idx - 1), layer_name, param_name])

    if new_key is None:
        # Nothing to rename: keep the entry as is rather than raising an UnboundLocalError
        return

    param = state_dict.pop(key)
    if is_bias:
        # Biases are squeezed to drop singleton dimensions
        param = param.squeeze()
    state_dict[new_key] = param
|
||||
|
||||
|
||||
def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
    """
    Convert a single face adapter entry of the Animate model, modifying `state_dict` in place.

    The original model uses a fused KV projection but the diffusers model uses separate K and V projections, so
    `linear1_kv` parameters are split into two equal chunks along dim 0 (`to_k` first, then `to_v`). The remaining
    known layers are only renamed. Keys of the form `...fuser_blocks.<block>.<layer>.<param>` are rewritten to
    `face_adapter.<block>.<new_layer>.<param>`.

    Entries that do not match this layout, or whose layer name is not known, are left untouched instead of failing on
    an unbound variable, so they can be reported by the subsequent strict state dict load.

    Args:
        key (`str`):
            Name of the state dict entry to convert.
        state_dict (`Dict[str, Any]`):
            The checkpoint mapping. The entry is removed and re-inserted under its new name(s).
    """
    # Skip if not a weight or bias
    if ".weight" not in key and ".bias" not in key:
        return
    if ".fuser_blocks." not in key:
        return

    parts = key.split(".")
    if "fuser_blocks" not in parts:
        return
    module_list_idx = parts.index("fuser_blocks")
    # Expect exactly `<block_idx>.<layer_name>.<param_name>` after `fuser_blocks`
    if (len(parts) - 1) - module_list_idx != 3:
        return

    block_idx, layer_name, param_name = parts[module_list_idx + 1 : module_list_idx + 4]
    prefix = "face_adapter."

    if layer_name == "linear1_kv":
        # Split the fused KV projection into separate K and V projections
        kv_proj = state_dict.pop(key)
        k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
        state_dict[prefix + ".".join([block_idx, "to_k", param_name])] = k_proj
        state_dict[prefix + ".".join([block_idx, "to_v", param_name])] = v_proj
        return

    layer_renames = {
        "q_norm": "norm_q",
        "k_norm": "norm_k",
        "linear1_q": "to_q",
        "linear2": "to_out",
    }
    new_layer_name = layer_renames.get(layer_name)
    if new_layer_name is None:
        # Unknown layer: keep the entry as is rather than raising an UnboundLocalError
        return

    new_key = prefix + ".".join([block_idx, new_layer_name, param_name])
    state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"motion_encoder": convert_animate_motion_encoder_weights,
|
||||
"face_adapter": convert_animate_face_adapter_weights,
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
@@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-Animate-14B":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.2-Animate-14B",
|
||||
"diffusers_config": {
|
||||
"image_dim": 1280,
|
||||
"added_kv_proj_dim": 5120,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 36,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": (1, 2, 2),
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"rope_max_seq_len": 1024,
|
||||
"pos_embed_seq_len": None,
|
||||
"motion_encoder_size": 512, # Start of Wan Animate-specific configs
|
||||
"motion_style_dim": 512,
|
||||
"motion_dim": 20,
|
||||
"motion_encoder_dim": 512,
|
||||
"face_encoder_hidden_dim": 1024,
|
||||
"face_encoder_num_heads": 4,
|
||||
"inject_face_latents_blocks": 5,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
|
||||
|
||||
|
||||
@@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None):
|
||||
original_state_dict = load_sharded_safetensors(model_dir)
|
||||
|
||||
with init_empty_weights():
|
||||
if "VACE" not in model_type:
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
else:
|
||||
if "Animate" in model_type:
|
||||
transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
|
||||
elif "VACE" in model_type:
|
||||
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
|
||||
else:
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None):
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
# Load state dict into the meta model, which will materialize the tensors
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
|
||||
# Move to CPU to ensure all tensors are materialized
|
||||
transformer = transformer.to("cpu")
|
||||
|
||||
return transformer
|
||||
|
||||
|
||||
@@ -926,7 +1168,7 @@ DTYPE_MAPPING = {
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
|
||||
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
|
||||
transformer = convert_transformer(args.model_type, stage="high_noise_model")
|
||||
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
|
||||
else:
|
||||
@@ -942,7 +1184,7 @@ if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
if "FLF2V" in args.model_type:
|
||||
flow_shift = 16.0
|
||||
elif "TI2V" in args.model_type:
|
||||
elif "TI2V" in args.model_type or "Animate" in args.model_type:
|
||||
flow_shift = 5.0
|
||||
else:
|
||||
flow_shift = 3.0
|
||||
@@ -954,6 +1196,8 @@ if __name__ == "__main__":
|
||||
if args.dtype != "none":
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
transformer.to(dtype)
|
||||
if transformer_2 is not None:
|
||||
transformer_2.to(dtype)
|
||||
|
||||
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
|
||||
pipe = WanImageToVideoPipeline(
|
||||
@@ -1016,6 +1260,21 @@ if __name__ == "__main__":
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
elif "Animate" in args.model_type:
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
|
||||
)
|
||||
image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
||||
|
||||
pipe = WanAnimatePipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
else:
|
||||
pipe = WanPipeline(
|
||||
transformer=transformer,
|
||||
|
||||
@@ -202,6 +202,7 @@ else:
|
||||
"BriaTransformer2DModel",
|
||||
"CacheMixin",
|
||||
"ChromaTransformer2DModel",
|
||||
"ChronoEditTransformer3DModel",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"CogView4Transformer2DModel",
|
||||
@@ -267,6 +268,7 @@ else:
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"UVit2DModel",
|
||||
"VQModel",
|
||||
"WanAnimateTransformer3DModel",
|
||||
"WanTransformer3DModel",
|
||||
"WanVACETransformer3DModel",
|
||||
"attention_backend",
|
||||
@@ -406,6 +408,7 @@ else:
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanModularPipeline",
|
||||
]
|
||||
@@ -436,6 +439,7 @@ else:
|
||||
"BriaPipeline",
|
||||
"ChromaImg2ImgPipeline",
|
||||
"ChromaPipeline",
|
||||
"ChronoEditPipeline",
|
||||
"CLIPImageProjection",
|
||||
"CogVideoXFunControlPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
@@ -541,11 +545,13 @@ else:
|
||||
"QwenImagePipeline",
|
||||
"ReduxImageEncoder",
|
||||
"SanaControlNetPipeline",
|
||||
"SanaImageToVideoPipeline",
|
||||
"SanaPAGPipeline",
|
||||
"SanaPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SanaVideoPipeline",
|
||||
"SanaVideoPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -633,6 +639,7 @@ else:
|
||||
"VisualClozeGenerationPipeline",
|
||||
"VisualClozePipeline",
|
||||
"VQDiffusionPipeline",
|
||||
"WanAnimatePipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
"WanVACEPipeline",
|
||||
@@ -909,6 +916,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BriaTransformer2DModel,
|
||||
CacheMixin,
|
||||
ChromaTransformer2DModel,
|
||||
ChronoEditTransformer3DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
@@ -973,6 +981,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
VQModel,
|
||||
WanAnimateTransformer3DModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
attention_backend,
|
||||
@@ -1087,6 +1096,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
WanModularPipeline,
|
||||
)
|
||||
@@ -1113,6 +1123,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BriaPipeline,
|
||||
ChromaImg2ImgPipeline,
|
||||
ChromaPipeline,
|
||||
ChronoEditPipeline,
|
||||
CLIPImageProjection,
|
||||
CogVideoXFunControlPipeline,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
@@ -1218,6 +1229,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImagePipeline,
|
||||
ReduxImageEncoder,
|
||||
SanaControlNetPipeline,
|
||||
SanaImageToVideoPipeline,
|
||||
SanaPAGPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintImg2ImgPipeline,
|
||||
@@ -1309,6 +1321,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VisualClozeGenerationPipeline,
|
||||
VisualClozePipeline,
|
||||
VQDiffusionPipeline,
|
||||
WanAnimatePipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanVACEPipeline,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -88,6 +88,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -99,6 +99,19 @@ class AdaptiveProjectedMixGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -141,6 +141,16 @@ class AutoGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -99,6 +99,16 @@ class ClassifierFreeGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -85,6 +85,16 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -226,6 +226,16 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -166,6 +166,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
@@ -234,6 +239,51 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@classmethod
def _prepare_batch_from_block_state(
    cls,
    input_fields: Dict[str, Union[str, Tuple[str, str]]],
    data: "BlockState",
    tuple_index: int,
    identifier: str,
) -> "BlockState":
    """
    Prepares a batch of data for the guidance technique. This method is used in the
    `prepare_inputs_from_block_state` method of the guidance classes. It prepares the batch based on the provided
    tuple index.

    Args:
        input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
            A dictionary where the keys are the names of the fields that will be used to store the data once it is
            prepared. The values can be either a string or a tuple of length 2, which is used to look up the
            required data on `data`. If a string is provided, it is used for every batch. If a tuple of length 2 is
            provided, the first element must be the conditional data identifier and the second element must be the
            unconditional data identifier or None.
        data (`BlockState`):
            The input data to be prepared.
        tuple_index (`int`):
            The index to use when accessing input fields that are tuples.
        identifier (`str`):
            The value stored under `cls._identifier_key` in the returned batch.

    Returns:
        `BlockState`: The prepared batch of data. Fields whose identifier is `None` or that are missing on `data`
        are omitted.
    """
    from ..modular_pipelines.modular_pipeline import BlockState

    data_batch = {}
    for key, value in input_fields.items():
        if isinstance(value, str):
            attr_name = value
        elif isinstance(value, tuple):
            attr_name = value[tuple_index]
        else:
            # Anything else is not a supported lookup specification; nothing to fetch for this field
            continue

        if attr_name is None:
            # A `None` identifier (allowed for the unconditional slot) means there is no data for this batch.
            # `getattr(data, None)` would raise a TypeError, so skip the field explicitly.
            continue

        try:
            data_batch[key] = getattr(data, attr_name)
        except AttributeError:
            logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
    data_batch[cls._identifier_key] = identifier
    return BlockState(**data_batch)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
@@ -323,7 +373,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
|
||||
@@ -187,6 +187,26 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -183,6 +183,26 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
@@ -172,6 +172,26 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -74,6 +74,16 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
|
||||
@@ -409,7 +409,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
src_w = width if ratio < src_ratio else image.width * height // image.height
|
||||
src_h = height if ratio >= src_ratio else image.height * width // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
|
||||
@@ -460,7 +460,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
src_w = width if ratio > src_ratio else image.width * height // image.height
|
||||
src_h = height if ratio <= src_ratio else image.height * width // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
return res
|
||||
|
||||
@@ -86,6 +86,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
@@ -107,6 +108,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
@@ -179,6 +181,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BriaFiboTransformer2DModel,
|
||||
BriaTransformer2DModel,
|
||||
ChromaTransformer2DModel,
|
||||
ChronoEditTransformer3DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
@@ -212,6 +215,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
WanAnimateTransformer3DModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
|
||||
@@ -44,11 +44,16 @@ class ContextParallelConfig:
|
||||
|
||||
Args:
|
||||
ring_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
|
||||
attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
|
||||
of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
|
||||
for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
|
||||
context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
|
||||
ulysses_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
|
||||
local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
|
||||
KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
|
||||
good interconnect bandwidth.
|
||||
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert output and LSE to float32 for ring attention numerical stability.
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
@@ -79,29 +84,46 @@ class ContextParallelConfig:
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
|
||||
if self.ring_degree == 1 and self.ulysses_degree == 1:
|
||||
raise ValueError(
|
||||
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
|
||||
)
|
||||
if self.ring_degree < 1 or self.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if self.ring_degree > 1 and self.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
|
||||
@property
|
||||
def mesh_shape(self) -> Tuple[int, int]:
|
||||
return (self.ring_degree, self.ulysses_degree)
|
||||
|
||||
@property
|
||||
def mesh_dim_names(self) -> Tuple[str, str]:
|
||||
"""Dimension names for the device mesh."""
|
||||
return ("ring", "ulysses")
|
||||
|
||||
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._mesh = mesh
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
|
||||
if self.ulysses_degree * self.ring_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
if self._flattened_mesh is None:
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
if self._ring_mesh is None:
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
if self._ulysses_mesh is None:
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
if self._ring_local_rank is None:
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
if self._ulysses_local_rank is None:
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -119,7 +141,7 @@ class ParallelConfig:
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
|
||||
def setup(
|
||||
self,
|
||||
@@ -127,14 +149,14 @@ class ParallelConfig:
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
*,
|
||||
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._cp_mesh = cp_mesh
|
||||
self._mesh = mesh
|
||||
if self.context_parallel_config is not None:
|
||||
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
|
||||
self.context_parallel_config.setup(rank, world_size, device, mesh)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -16,6 +16,7 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
@@ -42,7 +43,7 @@ from ..utils import (
|
||||
is_xformers_available,
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -82,24 +83,11 @@ else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
if not is_kernels_available():
|
||||
raise ImportError(
|
||||
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
|
||||
)
|
||||
from ..utils.kernels_utils import _get_fa3_from_hub
|
||||
|
||||
flash_attn_interface_hub = _get_fa3_from_hub()
|
||||
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
|
||||
else:
|
||||
flash_attn_3_func_hub = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
@@ -220,7 +208,7 @@ class _AttentionBackendRegistry:
|
||||
_backends = {}
|
||||
_constraints = {}
|
||||
_supported_arg_names = {}
|
||||
_supports_context_parallel = {}
|
||||
_supports_context_parallel = set()
|
||||
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
|
||||
_checks_enabled = DIFFUSERS_ATTN_CHECKS
|
||||
|
||||
@@ -237,7 +225,9 @@ class _AttentionBackendRegistry:
|
||||
cls._backends[backend] = func
|
||||
cls._constraints[backend] = constraints or []
|
||||
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
|
||||
cls._supports_context_parallel[backend] = supports_context_parallel
|
||||
if supports_context_parallel:
|
||||
cls._supports_context_parallel.add(backend.value)
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@@ -251,15 +241,31 @@ class _AttentionBackendRegistry:
|
||||
return list(cls._backends.keys())
|
||||
|
||||
@classmethod
|
||||
def _is_context_parallel_enabled(
|
||||
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
|
||||
def _is_context_parallel_available(
|
||||
cls,
|
||||
backend: AttentionBackendName,
|
||||
) -> bool:
|
||||
supports_context_parallel = backend in cls._supports_context_parallel
|
||||
is_degree_greater_than_1 = parallel_config is not None and (
|
||||
parallel_config.context_parallel_config.ring_degree > 1
|
||||
or parallel_config.context_parallel_config.ulysses_degree > 1
|
||||
)
|
||||
return supports_context_parallel and is_degree_greater_than_1
|
||||
supports_context_parallel = backend.value in cls._supports_context_parallel
|
||||
return supports_context_parallel
|
||||
|
||||
|
||||
@dataclass
|
||||
class _HubKernelConfig:
|
||||
"""Configuration for downloading and using a hub-based attention kernel."""
|
||||
|
||||
repo_id: str
|
||||
function_attr: str
|
||||
revision: Optional[str] = None
|
||||
kernel_fn: Optional[Callable] = None
|
||||
|
||||
|
||||
# Registry for hub-based attention kernels
|
||||
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -306,14 +312,6 @@ def dispatch_attention_fn(
|
||||
backend_name = AttentionBackendName(backend)
|
||||
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
|
||||
|
||||
if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
|
||||
backend_name, parallel_config
|
||||
):
|
||||
raise ValueError(
|
||||
f"Backend {backend_name} either does not support context parallelism or context parallelism "
|
||||
f"was enabled with a world size of 1."
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"query": query,
|
||||
"key": key,
|
||||
@@ -392,12 +390,18 @@ def _check_shape(
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# Expected shapes:
|
||||
# query: (batch_size, seq_len_q, num_heads, head_dim)
|
||||
# key: (batch_size, seq_len_kv, num_heads, head_dim)
|
||||
# value: (batch_size, seq_len_kv, num_heads, head_dim)
|
||||
# attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
|
||||
# or (batch_size, num_heads, seq_len_q, seq_len_kv)
|
||||
if query.shape[-1] != key.shape[-1]:
|
||||
raise ValueError("Query and key must have the same last dimension.")
|
||||
if query.shape[-2] != value.shape[-2]:
|
||||
raise ValueError("Query and value must have the same second to last dimension.")
|
||||
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
|
||||
raise ValueError("Attention mask must match the key's second to last dimension.")
|
||||
raise ValueError("Query and key must have the same head dimension.")
|
||||
if key.shape[-3] != value.shape[-3]:
|
||||
raise ValueError("Key and value must have the same sequence length.")
|
||||
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
|
||||
raise ValueError("Attention mask must match the key's sequence length.")
|
||||
|
||||
|
||||
# ===== Helper functions =====
|
||||
@@ -418,13 +422,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
|
||||
# TODO: add support Hub variant of FA3 varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB]:
|
||||
if not DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
|
||||
)
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
@@ -574,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
|
||||
# ===== Helpers for downloading kernels =====
|
||||
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
if backend not in _HUB_KERNELS_REGISTRY:
|
||||
return
|
||||
config = _HUB_KERNELS_REGISTRY[backend]
|
||||
|
||||
if config.kernel_fn is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
from kernels import get_kernel
|
||||
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
||||
kernel_func = getattr(kernel_module, config.function_attr)
|
||||
|
||||
# Cache the downloaded kernel function in the config object
|
||||
config.kernel_fn = kernel_func
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# ===== torch op registrations =====
|
||||
# Registrations are required for fullgraph tracing compatibility
|
||||
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
|
||||
@@ -1421,7 +1444,8 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
out = flash_attn_3_func_hub(
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
|
||||
@@ -102,7 +102,7 @@ def get_block(
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: Tuple[int] = (),
|
||||
qkv_mutliscales: Tuple[int, ...] = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
@@ -206,8 +206,8 @@ class Encoder(nn.Module):
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
@@ -292,8 +292,8 @@ class Decoder(nn.Module):
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int, ...] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
||||
act_fn: Union[str, Tuple[str]] = "silu",
|
||||
@@ -440,8 +440,8 @@ class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_layers_per_block: Tuple[int, ...] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int, ...] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
|
||||
@@ -78,9 +78,9 @@ class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModel
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 4,
|
||||
|
||||
@@ -995,19 +995,19 @@ class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
"CogVideoXDownBlock3D",
|
||||
),
|
||||
up_block_types: Tuple[str] = (
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
"CogVideoXUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (128, 256, 256, 512),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
|
||||
latent_channels: int = 16,
|
||||
layers_per_block: int = 3,
|
||||
act_fn: str = "silu",
|
||||
|
||||
@@ -653,7 +653,7 @@ class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
"HunyuanVideoUpBlock3D",
|
||||
"HunyuanVideoUpBlock3D",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: int = 32,
|
||||
|
||||
@@ -601,7 +601,7 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
|
||||
layers_per_block: int = 2,
|
||||
spatial_compression_ratio: int = 16,
|
||||
temporal_compression_ratio: int = 4,
|
||||
|
||||
@@ -688,8 +688,8 @@ class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 15,
|
||||
out_channels: int = 3,
|
||||
encoder_block_out_channels: Tuple[int] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: Tuple[int] = (128, 256, 512, 768),
|
||||
encoder_block_out_channels: Tuple[int, ...] = (64, 128, 256, 384),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
|
||||
latent_channels: int = 12,
|
||||
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
|
||||
act_fn: str = "silu",
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
|
||||
# For more information about the Wan VAE, please refer to:
|
||||
# - GitHub: https://github.com/Wan-Video/Wan2.1
|
||||
# - arXiv: https://arxiv.org/abs/2503.20314
|
||||
# - Paper: https://huggingface.co/papers/2503.20314
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -679,7 +679,7 @@ class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
self,
|
||||
base_dim: int = 96,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
dim_mult: Tuple[int, ...] = (1, 2, 4, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
|
||||
@@ -31,7 +31,7 @@ class TemporalDecoder(nn.Module):
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512),
|
||||
block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
layers_per_block: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -172,8 +172,8 @@ class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 3,
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int] = (64,),
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
|
||||
block_out_channels: Tuple[int, ...] = (64,),
|
||||
layers_per_block: int = 1,
|
||||
latent_channels: int = 4,
|
||||
sample_size: int = 32,
|
||||
|
||||
@@ -971,7 +971,7 @@ class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalMo
|
||||
base_dim: int = 96,
|
||||
decoder_base_dim: Optional[int] = None,
|
||||
z_dim: int = 16,
|
||||
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
||||
dim_mult: List[int] = [1, 2, 4, 4],
|
||||
num_res_blocks: int = 2,
|
||||
attn_scales: List[float] = [],
|
||||
temperal_downsample: List[bool] = [False, True, True],
|
||||
|
||||
@@ -293,14 +293,14 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
time_embedding_mix: float = 1.0,
|
||||
learn_time_embedding: bool = False,
|
||||
num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
|
||||
base_block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
cross_attention_dim: int = 1024,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
@@ -436,7 +436,7 @@ class ControlNetXSAdapter(ModelMixin, ConfigMixin):
|
||||
time_embedding_mix: int = 1.0,
|
||||
conditioning_channels: int = 3,
|
||||
conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].
|
||||
@@ -529,14 +529,19 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
# unet configs
|
||||
sample_size: Optional[int] = 96,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
|
||||
@@ -550,10 +555,10 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
|
||||
# additional controlnet configs
|
||||
time_embedding_mix: float = 1.0,
|
||||
ctrl_conditioning_channels: int = 3,
|
||||
ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_embedding_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
|
||||
ctrl_conditioning_channel_order: str = "rgb",
|
||||
ctrl_learn_time_embedding: bool = False,
|
||||
ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
|
||||
ctrl_block_out_channels: Tuple[int, ...] = (4, 8, 16, 16),
|
||||
ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
|
||||
ctrl_max_norm_num_groups: int = 32,
|
||||
):
|
||||
|
||||
@@ -595,7 +595,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
attention as backend.
|
||||
"""
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
|
||||
from .attention_dispatch import (
|
||||
AttentionBackendName,
|
||||
_check_attention_backend_requirements,
|
||||
_maybe_download_kernel_for_backend,
|
||||
)
|
||||
|
||||
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
@@ -606,8 +610,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
|
||||
if backend not in available_backends:
|
||||
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
|
||||
|
||||
backend = AttentionBackendName(backend)
|
||||
_check_attention_backend_requirements(backend)
|
||||
_maybe_download_kernel_for_backend(backend)
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
@@ -1484,59 +1490,71 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
config: Union[ParallelConfig, ContextParallelConfig],
|
||||
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
|
||||
):
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
logger.warning(
|
||||
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
|
||||
)
|
||||
|
||||
if not torch.distributed.is_available() and not torch.distributed.is_initialized():
|
||||
raise RuntimeError(
|
||||
"torch.distributed must be available and initialized before calling `enable_parallelism`."
|
||||
)
|
||||
|
||||
from ..hooks.context_parallel import apply_context_parallel
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
|
||||
if isinstance(config, ContextParallelConfig):
|
||||
config = ParallelConfig(context_parallel_config=config)
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device_type = torch._C._get_accelerator().type
|
||||
device_module = torch.get_device_module(device_type)
|
||||
device = torch.device(device_type, rank % device_module.device_count())
|
||||
|
||||
cp_mesh = None
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
|
||||
processor = module.processor
|
||||
if processor is None or not hasattr(processor, "_attention_backend"):
|
||||
continue
|
||||
|
||||
attention_backend = processor._attention_backend
|
||||
if attention_backend is None:
|
||||
attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
else:
|
||||
attention_backend = AttentionBackendName(attention_backend)
|
||||
|
||||
if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
|
||||
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
|
||||
raise ValueError(
|
||||
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
|
||||
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
|
||||
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
|
||||
f"calling `enable_parallelism()`."
|
||||
)
|
||||
|
||||
# All modules use the same attention processor and backend. We don't need to
|
||||
# iterate over all modules after checking the first processor
|
||||
break
|
||||
|
||||
mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
|
||||
mesh_dim_names=("ring", "ulysses"),
|
||||
mesh_shape=cp_config.mesh_shape,
|
||||
mesh_dim_names=cp_config.mesh_dim_names,
|
||||
)
|
||||
|
||||
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
|
||||
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
config.setup(rank, world_size, device, mesh=mesh)
|
||||
self._parallel_config = config
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
if not isinstance(module, attention_classes):
|
||||
continue
|
||||
@@ -1545,6 +1563,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
continue
|
||||
processor._parallel_config = config
|
||||
|
||||
if config.context_parallel_config is not None:
|
||||
if cp_plan is None and self._cp_plan is None:
|
||||
raise ValueError(
|
||||
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
|
||||
)
|
||||
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
|
||||
apply_context_parallel(self, config.context_parallel_config, cp_plan)
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
|
||||
@@ -20,6 +20,7 @@ if is_torch_available():
|
||||
from .transformer_bria import BriaTransformer2DModel
|
||||
from .transformer_bria_fibo import BriaFiboTransformer2DModel
|
||||
from .transformer_chroma import ChromaTransformer2DModel
|
||||
from .transformer_chronoedit import ChronoEditTransformer3DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_cogview4 import CogView4Transformer2DModel
|
||||
from .transformer_cosmos import CosmosTransformer3DModel
|
||||
@@ -41,4 +42,5 @@ if is_torch_available():
|
||||
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
from .transformer_wan import WanTransformer3DModel
|
||||
from .transformer_wan_animate import WanAnimateTransformer3DModel
|
||||
from .transformer_wan_vace import WanVACETransformer3DModel
|
||||
|
||||
@@ -0,0 +1,735 @@
|
||||
# Copyright 2025 The ChronoEdit Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Adapted from diffusers.models.transformers.transformer_wan._get_qkv_projections
def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
    """Compute the query, key and value projections for a Wan attention layer.

    Args:
        attn: Attention module providing the projection layers (`to_q`/`to_k`/`to_v`, or the fused
            `to_qkv`/`to_kv` layers when `attn.fused_projections` is truthy).
        hidden_states: Tensor the query is projected from.
        encoder_hidden_states: Tensor the key and value are projected from. `None` means
            self-attention, in which case `hidden_states` is used instead.

    Returns:
        The `(query, key, value)` tuple.
    """
    # encoder_hidden_states is only passed for cross-attention
    context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

    if not attn.fused_projections:
        return attn.to_q(hidden_states), attn.to_k(context), attn.to_v(context)

    if attn.cross_attention_dim_head is None:
        # Self-attention: the whole QKV projection lives in a single fused linear layer.
        query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
        return query, key, value

    # Cross-attention: only the KV projections can be fused, the query keeps its own layer.
    query = attn.to_q(hidden_states)
    key, value = attn.to_kv(context).chunk(2, dim=-1)
    return query, key, value
|
||||
|
||||
|
||||
# Adapted from diffusers.models.transformers.transformer_wan._get_added_kv_projections
def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
    """Project the image context into the additional key/value pair of a Wan attention layer.

    Args:
        attn: Attention module providing `add_k_proj`/`add_v_proj`, or the fused `to_added_kv`
            layer when `attn.fused_projections` is truthy.
        encoder_hidden_states_img: Image context tensor to project.

    Returns:
        The `(key_img, value_img)` tuple.
    """
    if not attn.fused_projections:
        return attn.add_k_proj(encoder_hidden_states_img), attn.add_v_proj(encoder_hidden_states_img)

    # The fused layer emits keys and values concatenated along the last dimension.
    key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
    return key_img, value_img
|
||||
|
||||
|
||||
# Adapted from diffusers.models.transformers.transformer_wan.WanAttnProcessor
class WanAttnProcessor:
    """Attention processor for `WanAttention` layers.

    Handles self-attention (optionally with rotary embeddings) and cross-attention. When the
    attention module carries added key/value projections (`attn.add_k_proj`), the leading part of
    the context is treated as image tokens and attended to separately.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
        )

    @staticmethod
    def _apply_rotary_emb(
        hidden_states: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        """Rotate consecutive channel pairs of `hidden_states` by the given cos/sin tables."""
        x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
        # The tables hold each frequency twice in a row; keep one entry per channel pair.
        cos = freqs_cos[..., 0::2]
        sin = freqs_sin[..., 1::2]
        rotated = torch.empty_like(hidden_states)
        rotated[..., 0::2] = x1 * cos - x2 * sin
        rotated[..., 1::2] = x1 * sin + x2 * cos
        return rotated.type_as(hidden_states)

    def _attend(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Dispatch one attention call and merge the head and head-dim axes of the result."""
        attended = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        return attended.flatten(2, 3).type_as(query)

    def __call__(
        self,
        attn: "WanAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run attention for `attn` and return the output-projected hidden states.

        Args:
            attn: The attention module whose projections and norms are used.
            hidden_states: Query-side input.
            encoder_hidden_states: Context for cross-attention; `None` for self-attention.
            attention_mask: Optional mask forwarded to the main attention call.
            rotary_emb: Optional `(freqs_cos, freqs_sin)` pair applied to query and key.
        """
        image_context = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            num_image_tokens = encoder_hidden_states.shape[1] - 512
            image_context = encoder_hidden_states[:, :num_image_tokens]
            encoder_hidden_states = encoder_hidden_states[:, num_image_tokens:]

        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

        head_shape = (attn.heads, -1)
        query = attn.norm_q(query).unflatten(2, head_shape)
        key = attn.norm_k(key).unflatten(2, head_shape)
        value = value.unflatten(2, head_shape)

        if rotary_emb is not None:
            query = self._apply_rotary_emb(query, *rotary_emb)
            key = self._apply_rotary_emb(key, *rotary_emb)

        # I2V task: attend to the image tokens with the dedicated key/value projections.
        image_attention = None
        if image_context is not None:
            key_img, value_img = _get_added_kv_projections(attn, image_context)
            key_img = attn.norm_added_k(key_img).unflatten(2, head_shape)
            value_img = value_img.unflatten(2, head_shape)
            image_attention = self._attend(query, key_img, value_img, None)

        hidden_states = self._attend(query, key, value, attention_mask)
        if image_attention is not None:
            hidden_states = hidden_states + image_attention

        # to_out[0] is the output linear, to_out[1] the dropout.
        return attn.to_out[1](attn.to_out[0](hidden_states))
|
||||
|
||||
|
||||
# Adapted from diffusers.models.transformers.transformer_wan.WanAttnProcessor2_0
class WanAttnProcessor2_0:
    """Deprecated name for `WanAttnProcessor`.

    Instantiating this class emits a deprecation notice and returns a `WanAttnProcessor` instance
    built from the same arguments.
    """

    def __new__(cls, *args, **kwargs):
        deprecate(
            "WanAttnProcessor2_0",
            "1.0.0",
            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
            "Please use WanAttnProcessor instead. ",
            standard_warn=False,
        )
        return WanAttnProcessor(*args, **kwargs)
|
||||
|
||||
|
||||
# Adapted from diffusers.models.transformers.transformer_wan.WanAttention
class WanAttention(torch.nn.Module, AttentionModuleMixin):
    """Attention layer used by the Wan-style transformer blocks.

    Args:
        dim: Input and output feature size of the layer.
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        eps: Epsilon of the query/key RMS norms.
        dropout: Dropout probability applied after the output projection.
        added_kv_proj_dim: If set, feature size of the extra context that gets its own key/value
            projections (`add_k_proj`/`add_v_proj`) and key norm.
        cross_attention_dim_head: Per-head size of the key/value projections for cross-attention.
            `None` marks the layer as self-attention.
        processor: Attention processor handed to `set_processor`.
        is_cross_attention: Accepted but not read; the attribute of the same name is derived from
            `cross_attention_dim_head`.
    """

    _default_processor_cls = WanAttnProcessor
    _available_processors = [WanAttnProcessor]

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        eps: float = 1e-5,
        dropout: float = 0.0,
        added_kv_proj_dim: Optional[int] = None,
        cross_attention_dim_head: Optional[int] = None,
        processor=None,
        is_cross_attention=None,
    ):
        super().__init__()

        qk_dim = dim_head * heads
        self.inner_dim = qk_dim
        self.heads = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.cross_attention_dim_head = cross_attention_dim_head
        if cross_attention_dim_head is None:
            self.kv_inner_dim = qk_dim
        else:
            self.kv_inner_dim = cross_attention_dim_head * heads

        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        output_projection = torch.nn.Linear(self.inner_dim, dim, bias=True)
        self.to_out = torch.nn.ModuleList([output_projection, torch.nn.Dropout(dropout)])
        self.norm_q = torch.nn.RMSNorm(qk_dim, eps=eps, elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(qk_dim, eps=eps, elementwise_affine=True)

        if added_kv_proj_dim is None:
            self.add_k_proj = None
            self.add_v_proj = None
        else:
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.norm_added_k = torch.nn.RMSNorm(qk_dim, eps=eps)

        self.is_cross_attention = cross_attention_dim_head is not None

        self.set_processor(processor)

    @staticmethod
    def _build_fused_linear(*layers):
        """Return one linear layer whose weight/bias are the given layers' stacked along the output axis."""
        weight = torch.cat([layer.weight.data for layer in layers])
        bias = torch.cat([layer.bias.data for layer in layers])
        out_features, in_features = weight.shape
        # Create the layer on the meta device, then attach the concatenated tensors
        # directly via `assign=True`.
        with torch.device("meta"):
            fused = nn.Linear(in_features, out_features, bias=True)
        fused.load_state_dict({"weight": weight, "bias": bias}, strict=True, assign=True)
        return fused

    def fuse_projections(self):
        """Create the fused projection layers (`to_qkv` or `to_kv`, plus `to_added_kv` when applicable)."""
        if getattr(self, "fused_projections", False):
            return

        if self.cross_attention_dim_head is None:
            self.to_qkv = self._build_fused_linear(self.to_q, self.to_k, self.to_v)
        else:
            self.to_kv = self._build_fused_linear(self.to_k, self.to_v)

        if self.added_kv_proj_dim is not None:
            self.to_added_kv = self._build_fused_linear(self.add_k_proj, self.add_v_proj)

        self.fused_projections = True

    @torch.no_grad()
    def unfuse_projections(self):
        """Drop any fused projection layers and mark the module as unfused."""
        if not getattr(self, "fused_projections", False):
            return

        for fused_name in ("to_qkv", "to_kv", "to_added_kv"):
            if hasattr(self, fused_name):
                delattr(self, fused_name)

        self.fused_projections = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate the attention computation to the configured processor."""
        processor_args = (self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb)
        return self.processor(*processor_args, **kwargs)
|
||||
|
||||
|
||||
# Adapted from diffusers.models.transformers.transformer_wan.WanImageEmbedding
class WanImageEmbedding(torch.nn.Module):
    """Embed image features with norm -> feed-forward -> norm.

    Args:
        in_features: Feature size of the incoming image embeddings.
        out_features: Feature size of the produced embeddings.
        pos_embed_seq_len: If given, a learned positional embedding of shape
            `(1, pos_embed_seq_len, in_features)` is added before the first norm.
    """

    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        self.pos_embed = (
            None if pos_embed_seq_len is None else nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
        )

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        """Return the embedded image features."""
        image_states = encoder_hidden_states_image
        if self.pos_embed is not None:
            _, seq_len, embed_dim = image_states.shape
            # Pairs of batch entries are merged into one sequence of twice the length
            # before the positional embedding is added.
            image_states = image_states.view(-1, 2 * seq_len, embed_dim) + self.pos_embed

        return self.norm2(self.ff(self.norm1(image_states)))
|
||||
|
||||
|
||||
# Adapted from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
class WanTimeTextImageEmbedding(nn.Module):
    """Produce the timestep, text and (optional) image conditioning embeddings.

    Args:
        dim: Feature size of the produced timestep/text/image embeddings.
        time_freq_dim: Number of channels of the sinusoidal timestep projection.
        time_proj_dim: Output size of the linear layer applied to the activated timestep embedding.
        text_embed_dim: Feature size of the incoming text embeddings.
        image_embed_dim: Feature size of the incoming image embeddings; `None` disables the image
            embedder.
        pos_embed_seq_len: Forwarded to `WanImageEmbedding`.
    """

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()

        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

        self.image_embedder = (
            None
            if image_embed_dim is None
            else WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
        )

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        timestep_seq_len: Optional[int] = None,
    ):
        """Embed all conditioning inputs.

        Returns:
            `(temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image)` where the
            image entry is passed through unchanged when it is `None`.
        """
        projected = self.timesteps_proj(timestep)
        if timestep_seq_len is not None:
            # Restore the (batch, seq_len) layout of timesteps that were flattened by the caller.
            projected = projected.unflatten(0, (-1, timestep_seq_len))

        embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        # Match the embedder's dtype, except when its parameters are int8.
        needs_cast = not (projected.dtype == embedder_dtype or embedder_dtype == torch.int8)
        if needs_cast:
            projected = projected.to(embedder_dtype)

        temb = self.time_embedder(projected).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        text_states = self.text_embedder(encoder_hidden_states)
        image_states = encoder_hidden_states_image
        if image_states is not None:
            image_states = self.image_embedder(image_states)

        return temb, timestep_proj, text_states, image_states
|
||||
|
||||
|
||||
class ChronoEditRotaryPosEmbed(nn.Module):
    """Rotary position embedding tables for (frame, height, width) token grids.

    The head dimension is divided between the temporal, height and width axes. One cos/sin table
    per axis is precomputed for `max_seq_len` positions and the three are stored concatenated
    along the channel dimension.

    Args:
        attention_head_dim: Number of channels per attention head.
        patch_size: `(t, h, w)` patch size used to turn the latent into tokens.
        max_seq_len: Number of positions precomputed for every axis.
        theta: Base of the rotary frequencies.
        temporal_skip_len: For two-frame inputs, the two frames take the first and the last
            temporal position among the first `temporal_skip_len` positions.
    """

    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
        temporal_skip_len: int = 8,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len
        self.temporal_skip_len = temporal_skip_len

        # float64 is not used when MPS is available.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

        freqs_cos = []
        freqs_sin = []

        for dim in self._axis_dims():
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)

        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)

    def _axis_dims(self) -> Tuple[int, int, int]:
        """Return the channel counts assigned to the (temporal, height, width) axes."""
        h_dim = w_dim = 2 * (self.attention_head_dim // 6)
        t_dim = self.attention_head_dim - h_dim - w_dim
        return t_dim, h_dim, w_dim

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build the rotary tables for the token grid of `hidden_states`.

        Args:
            hidden_states: Tensor of shape `(batch, channels, frames, height, width)`; only the
                last three sizes are used.

        Returns:
            `(freqs_cos, freqs_sin)`, each reshaped to `(1, num_tokens, 1, channels)` with
            `num_tokens` the product of the post-patch frame, height and width counts.
        """
        _, _, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        # Split with the same per-axis sizes that were used to build the tables. The previous
        # `attention_head_dim // 3` based sizes disagree with the construction whenever
        # `attention_head_dim // 3` is odd (e.g. 64), which mixed channels of different axes.
        split_sizes = list(self._axis_dims())

        def build_grid(table: torch.Tensor) -> torch.Tensor:
            axis_t, axis_h, axis_w = table.split(split_sizes, dim=1)
            if num_frames == 2:
                # Two-frame input: use the first and last of the first `temporal_skip_len`
                # temporal positions instead of two consecutive ones.
                axis_t = axis_t[: self.temporal_skip_len][[0, -1]]
            else:
                axis_t = axis_t[:ppf]
            grid_t = axis_t.view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
            grid_h = axis_h[:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
            grid_w = axis_w[:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
            return torch.cat([grid_t, grid_h, grid_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)

        return build_grid(self.freqs_cos), build_grid(self.freqs_sin)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
# Adapted from diffusers.models.transformers.transformer_wan.WanTransformerBlock
class WanTransformerBlock(nn.Module):
    """Transformer block: modulated self-attention, cross-attention and modulated feed-forward.

    Args:
        dim: Hidden size of the block.
        ffn_dim: Inner size of the feed-forward network.
        num_heads: Number of attention heads; each head has `dim // num_heads` channels.
        qk_norm: Accepted for signature compatibility; not read by this block.
        cross_attn_norm: Whether to apply an affine layer norm before cross-attention (identity
            otherwise).
        eps: Epsilon for the norms.
        added_kv_proj_dim: Forwarded to the cross-attention layer.
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        head_dim = dim // num_heads

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=head_dim,
            eps=eps,
            cross_attention_dim_head=None,
            processor=WanAttnProcessor(),
        )

        # 2. Cross-attention
        self.attn2 = WanAttention(
            dim=dim,
            heads=num_heads,
            dim_head=head_dim,
            eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
            cross_attention_dim_head=head_dim,
            processor=WanAttnProcessor(),
        )
        if cross_attn_norm:
            self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm2 = nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the block.

        Args:
            hidden_states: Token features.
            encoder_hidden_states: Cross-attention context.
            temb: Timestep modulation, either `(batch, 6, dim)` or `(batch, seq_len, 6, dim)`.
            rotary_emb: Rotary embedding forwarded to self-attention.
        """
        if temb.ndim == 4:
            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
            chunks = (self.scale_shift_table.unsqueeze(0) + temb.float()).chunk(6, dim=2)
            # each chunk: batch_size, seq_len, 1, inner_dim -> drop the singleton axis
            modulation = tuple(chunk.squeeze(2) for chunk in chunks)
        else:
            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
            modulation = (self.scale_shift_table + temb.float()).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation

        # 1. Self-attention
        attn_input = self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa
        attn_input = attn_input.type_as(hidden_states)
        self_attn_out = self.attn1(attn_input, None, None, rotary_emb)
        hidden_states = (hidden_states.float() + self_attn_out * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        cross_input = self.norm2(hidden_states.float()).type_as(hidden_states)
        cross_attn_out = self.attn2(cross_input, encoder_hidden_states, None, None)
        hidden_states = hidden_states + cross_attn_out

        # 3. Feed-forward
        ff_input = self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa
        ff_input = ff_input.type_as(hidden_states)
        ff_out = self.ffn(ff_input)
        hidden_states = (hidden_states.float() + ff_out.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states
|
||||
|
||||
|
||||
# modified from diffusers.models.transformers.transformer_wan.WanTransformer3DModel
|
||||
class ChronoEditTransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the ChronoEdit model.

    Args:
        patch_size (`Tuple[int, int, int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output. Falls back to `in_channels` when falsy.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            Query/key normalization mode forwarded to every transformer block.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        image_dim (`int`, *optional*, defaults to `None`):
            Input dimension for image embeddings. If `None`, no image embedder is created.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_max_seq_len (`int`, defaults to `1024`):
            Number of positions precomputed per axis by the rotary position embedding.
        pos_embed_seq_len (`int`, *optional*, defaults to `None`):
            Sequence length of the learned positional embedding of the image embedder. If `None`, none is used.
        rope_temporal_skip_len (`int`, defaults to `8`):
            Temporal skip length used by the rotary position embedding for two-frame inputs.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]
    _cp_plan = {
        "rope": {
            0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
            1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
        },
        "blocks.0": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.*": {
            "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int, int, int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        rope_temporal_skip_len: int = 8,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = ChronoEditRotaryPosEmbed(
            attention_head_dim, patch_size, rope_max_seq_len, temporal_skip_len=rope_temporal_skip_len
        )
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[Transformer2DModelOutput, Tuple[torch.Tensor]]:
        """Run the transformer on a latent video.

        Args:
            hidden_states: Latent input of shape `(batch, channels, frames, height, width)`.
            timestep: Timesteps of shape `(batch,)` or `(batch, seq_len)`.
            encoder_hidden_states: Text embeddings.
            encoder_hidden_states_image: Optional image embeddings; when given they are embedded
                and prepended to the text context.
            return_dict: Whether to return a `Transformer2DModelOutput` instead of a plain tuple.
            attention_kwargs: Optional kwargs; only the `"scale"` entry (LoRA scale) is read here.

        Returns:
            A `Transformer2DModelOutput` whose `sample` has the layout of `hidden_states`, or a
            one-element tuple holding that tensor when `return_dict` is `False`.
        """
        # Remember whether a scale was explicitly requested *before* it is popped below; checking
        # afterwards would never see it, so the warning in the non-PEFT branch could never fire.
        scale_requested = attention_kwargs is not None and attention_kwargs.get("scale", None) is not None

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif scale_requested:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        # Patchify, then flatten the (frames, height, width) grid into one token axis.
        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            # Image tokens go first; the attention processor slices them off from the front.
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
                )
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

        # 5. Output norm, projection & unpatchify
        if temb.ndim == 3:
            # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
            shift = shift.squeeze(2)
            scale = scale.squeeze(2)
        else:
            # batch_size, inner_dim
            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)

        # Move the shift and scale tensors to the same device as hidden_states.
        # When using multi-GPU inference via accelerate these will be on the
        # first device rather than the last device, which hidden_states ends up
        # on.
        shift = shift.to(hidden_states.device)
        scale = scale.to(hidden_states.device)

        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: (B, F*H*W, p_t*p_h*p_w*C) -> (B, C, F*p_t, H*p_h, W*p_w)
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|
||||
@@ -914,7 +914,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -139,7 +139,7 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
|
||||
image_condition_type: Optional[str] = None,
|
||||
has_image_proj: int = False,
|
||||
image_proj_dim: int = 1152,
|
||||
|
||||
@@ -689,7 +689,7 @@ class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
text_embed_dim: int = 3584,
|
||||
text_embed_2_dim: Optional[int] = None,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (64, 64),
|
||||
rope_axes_dim: Tuple[int, ...] = (64, 64),
|
||||
use_meanflow: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -275,7 +275,12 @@ class PRXEmbedND(nn.Module):
|
||||
|
||||
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
|
||||
is_mps = pos.device.type == "mps"
|
||||
is_npu = pos.device.type == "npu"
|
||||
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=dtype, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
|
||||
@@ -172,7 +172,6 @@ class SanaLinearAttnProcessor3_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
|
||||
class WanRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -189,6 +188,11 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
@@ -214,11 +218,7 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
@@ -237,7 +237,6 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
|
||||
class SanaModulatedNorm(nn.Module):
|
||||
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
@@ -247,7 +246,7 @@ class SanaModulatedNorm(nn.Module):
|
||||
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
|
||||
shift, scale = (scale_shift_table[None, None] + temb[:, :, None].to(scale_shift_table.device)).unbind(dim=2)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
return hidden_states
|
||||
|
||||
@@ -423,8 +422,8 @@ class SanaVideoTransformerBlock(nn.Module):
|
||||
|
||||
# 1. Modulation
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
self.scale_shift_table[None, None] + timestep.reshape(batch_size, timestep.shape[1], 6, -1)
|
||||
).unbind(dim=2)
|
||||
|
||||
# 2. Self Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
@@ -635,13 +634,16 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
||||
|
||||
if guidance is not None:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
|
||||
timestep.flatten(), guidance=guidance, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
timestep.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
timestep = timestep.view(batch_size, -1, timestep.size(-1))
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
|
||||
@@ -389,6 +389,10 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_cos = []
|
||||
freqs_sin = []
|
||||
|
||||
@@ -412,11 +416,7 @@ class SkyReelsV2RotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
@@ -570,7 +570,7 @@ class SkyReelsV2Transformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
@@ -362,6 +362,11 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
|
||||
h_dim = w_dim = 2 * (attention_head_dim // 6)
|
||||
t_dim = attention_head_dim - h_dim - w_dim
|
||||
|
||||
self.t_dim = t_dim
|
||||
self.h_dim = h_dim
|
||||
self.w_dim = w_dim
|
||||
|
||||
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
|
||||
|
||||
freqs_cos = []
|
||||
@@ -387,11 +392,7 @@ class WanRotaryPosEmbed(nn.Module):
|
||||
p_t, p_h, p_w = self.patch_size
|
||||
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
|
||||
|
||||
split_sizes = [
|
||||
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
|
||||
self.attention_head_dim // 3,
|
||||
self.attention_head_dim // 3,
|
||||
]
|
||||
split_sizes = [self.t_dim, self.h_dim, self.w_dim]
|
||||
|
||||
freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
|
||||
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
|
||||
@@ -563,7 +564,7 @@ class WanTransformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -182,7 +182,7 @@ class WanVACETransformer3DModel(
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Tuple[int] = (1, 2, 2),
|
||||
patch_size: Tuple[int, ...] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
|
||||
@@ -86,11 +86,11 @@ class UNet1DModel(ModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos: bool = True,
|
||||
use_timestep_embedding: bool = False,
|
||||
freq_shift: float = 0.0,
|
||||
down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: Tuple[str] = "UNetMidBlock1D",
|
||||
down_block_types: Tuple[str, ...] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
|
||||
up_block_types: Tuple[str, ...] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
|
||||
mid_block_type: str = "UNetMidBlock1D",
|
||||
out_block_type: str = None,
|
||||
block_out_channels: Tuple[int] = (32, 32, 64),
|
||||
block_out_channels: Tuple[int, ...] = (32, 32, 64),
|
||||
act_fn: str = None,
|
||||
norm_num_groups: int = 8,
|
||||
layers_per_block: int = 1,
|
||||
|
||||
@@ -177,16 +177,21 @@ class UNet2DConditionModel(
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -486,10 +491,10 @@ class UNet2DConditionModel(
|
||||
|
||||
def _check_config(
|
||||
self,
|
||||
down_block_types: Tuple[str],
|
||||
up_block_types: Tuple[str],
|
||||
down_block_types: Tuple[str, ...],
|
||||
up_block_types: Tuple[str, ...],
|
||||
only_cross_attention: Union[bool, Tuple[bool]],
|
||||
block_out_channels: Tuple[int],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
layers_per_block: Union[int, Tuple[int]],
|
||||
cross_attention_dim: Union[int, Tuple[int]],
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||
|
||||
@@ -54,7 +54,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
|
||||
groups: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
layers_per_block: Union[int, Tuple[int]] = 3,
|
||||
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
|
||||
block_out_channels: Tuple[int, ...] = (384, 768, 1536, 3072),
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 4096,
|
||||
encoder_hid_dim: int = 4096,
|
||||
):
|
||||
|
||||
@@ -73,25 +73,25 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
|
||||
sample_size: Optional[int] = None,
|
||||
in_channels: int = 8,
|
||||
out_channels: int = 4,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
),
|
||||
up_block_types: Tuple[str] = (
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
),
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
addition_time_embed_dim: int = 256,
|
||||
projection_class_embeddings_input_dim: int = 768,
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
cross_attention_dim: Union[int, Tuple[int]] = 1024,
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
||||
num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
|
||||
num_attention_heads: Union[int, Tuple[int, ...]] = (5, 10, 20, 20),
|
||||
num_frames: int = 25,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -145,10 +145,10 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
timestep_ratio_embedding_dim: int = 64,
|
||||
patch_size: int = 1,
|
||||
conditioning_dim: int = 2048,
|
||||
block_out_channels: Tuple[int] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int] = (24, 8),
|
||||
block_out_channels: Tuple[int, ...] = (2048, 2048),
|
||||
num_attention_heads: Tuple[int, ...] = (32, 32),
|
||||
down_num_layers_per_block: Tuple[int, ...] = (8, 24),
|
||||
up_num_layers_per_block: Tuple[int, ...] = (24, 8),
|
||||
down_blocks_repeat_mappers: Optional[Tuple[int]] = (
|
||||
1,
|
||||
1,
|
||||
@@ -167,7 +167,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
kernel_size=3,
|
||||
dropout: Union[float, Tuple[float]] = (0.1, 0.1),
|
||||
self_attn: Union[bool, Tuple[bool]] = True,
|
||||
timestep_conditioning_type: Tuple[str] = ("sca", "crp"),
|
||||
timestep_conditioning_type: Tuple[str, ...] = ("sca", "crp"),
|
||||
switch_level: Optional[Tuple[bool]] = None,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -532,8 +532,8 @@ class FlaxEncoder(nn.Module):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
||||
block_out_channels: Tuple[int] = (64,)
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
act_fn: str = "silu"
|
||||
@@ -650,8 +650,8 @@ class FlaxDecoder(nn.Module):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: int = (64,)
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
act_fn: str = "silu"
|
||||
@@ -823,9 +823,9 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
down_block_types: Tuple[str] = ("DownEncoderBlock2D",)
|
||||
up_block_types: Tuple[str] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int] = (64,)
|
||||
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",)
|
||||
up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",)
|
||||
block_out_channels: Tuple[int, ...] = (64,)
|
||||
layers_per_block: int = 1
|
||||
act_fn: str = "silu"
|
||||
latent_channels: int = 4
|
||||
|
||||
@@ -45,7 +45,7 @@ else:
|
||||
"InsertableDict",
|
||||
]
|
||||
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxAutoBlocks",
|
||||
"FluxModularPipeline",
|
||||
@@ -90,7 +90,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from .wan import WanAutoBlocks, WanModularPipeline
|
||||
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -861,6 +861,10 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
else:
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
if not len(self.block_names) == len(self.block_classes):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_names and block_classes must be the same."
|
||||
)
|
||||
|
||||
def _get_inputs(self):
|
||||
inputs = []
|
||||
@@ -1441,6 +1445,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
components_manager: Optional[ComponentsManager] = None,
|
||||
collection: Optional[str] = None,
|
||||
modular_config_dict: Optional[Dict[str, Any]] = None,
|
||||
config_dict: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -1492,23 +1498,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
|
||||
`_blocks_class_name` in the config dict
|
||||
"""
|
||||
if blocks is None:
|
||||
blocks_class_name = self.default_blocks_name
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
self.blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
|
||||
# update component_specs and config_specs from modular_repo
|
||||
if pretrained_model_name_or_path is not None:
|
||||
if modular_config_dict is None and config_dict is None and pretrained_model_name_or_path is not None:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -1524,52 +1515,59 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"local_files_only": local_files_only,
|
||||
"revision": revision,
|
||||
}
|
||||
# try to load modular_model_index.json
|
||||
try:
|
||||
config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f"modular_model_index.json not found: {e}")
|
||||
config_dict = None
|
||||
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
# all the components in modular_model_index.json are from_pretrained components
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
|
||||
library, class_name, component_spec_dict = value
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
modular_config_dict, config_dict = self._load_pipeline_config(
|
||||
pretrained_model_name_or_path, **load_config_kwargs
|
||||
)
|
||||
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
# if modular_model_index.json is not found, try to load model_index.json
|
||||
if blocks is None:
|
||||
if modular_config_dict is not None:
|
||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||
elif config_dict is not None:
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
else:
|
||||
logger.debug(" loading config from model_index.json")
|
||||
try:
|
||||
from diffusers import DiffusionPipeline
|
||||
blocks_class_name = None
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
config_dict = None
|
||||
self.blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
|
||||
# update component_specs and config_specs based on model_index.json
|
||||
if config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
|
||||
library, class_name = value
|
||||
component_spec_dict = {
|
||||
"repo": pretrained_model_name_or_path,
|
||||
"subfolder": name,
|
||||
"type_hint": (library, class_name),
|
||||
}
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if modular_config_dict is not None:
|
||||
for name, value in modular_config_dict.items():
|
||||
# all the components in modular_model_index.json are from_pretrained components
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
|
||||
library, class_name, component_spec_dict = value
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
# if `modular_config_dict` is None (i.e. `modular_model_index.json` is not found), update based on `config_dict` (i.e. `model_index.json`)
|
||||
elif config_dict is not None:
|
||||
for name, value in config_dict.items():
|
||||
if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
|
||||
library, class_name = value
|
||||
component_spec_dict = {
|
||||
"repo": pretrained_model_name_or_path,
|
||||
"subfolder": name,
|
||||
"type_hint": (library, class_name),
|
||||
}
|
||||
component_spec = self._dict_to_component_spec(name, component_spec_dict)
|
||||
component_spec.default_creation_method = "from_pretrained"
|
||||
self._component_specs[name] = component_spec
|
||||
elif name in self._config_specs:
|
||||
self._config_specs[name].default = value
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.")
|
||||
@@ -1601,6 +1599,35 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
return self.default_blocks_name
|
||||
|
||||
@classmethod
|
||||
def _load_pipeline_config(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
||||
**load_config_kwargs,
|
||||
):
|
||||
try:
|
||||
# try to load modular_model_index.json
|
||||
modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
return modular_config_dict, None
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" modular_model_index.json not found in the repo: {e}")
|
||||
|
||||
try:
|
||||
logger.debug(" try to load model_index.json")
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
return None, config_dict
|
||||
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
@@ -1655,42 +1682,33 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
"revision": revision,
|
||||
}
|
||||
|
||||
try:
|
||||
# try to load modular_model_index.json
|
||||
config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" modular_model_index.json not found in the repo: {e}")
|
||||
config_dict = None
|
||||
modular_config_dict, config_dict = cls._load_pipeline_config(
|
||||
pretrained_model_name_or_path, **load_config_kwargs
|
||||
)
|
||||
|
||||
if config_dict is not None:
|
||||
pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
if modular_config_dict is not None:
|
||||
pipeline_class = _get_pipeline_class(cls, config=modular_config_dict)
|
||||
elif config_dict is not None:
|
||||
from diffusers.pipelines.auto_pipeline import _get_model
|
||||
|
||||
logger.debug(" try to determine the modular pipeline class from model_index.json")
|
||||
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
model_name = _get_model(standard_pipeline_class.__name__)
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
else:
|
||||
try:
|
||||
logger.debug(" try to load model_index.json")
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.pipelines.auto_pipeline import _get_model
|
||||
|
||||
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
|
||||
except EnvironmentError as e:
|
||||
logger.debug(f" model_index.json not found in the repo: {e}")
|
||||
|
||||
if config_dict is not None:
|
||||
logger.debug(" try to determine the modular pipeline class from model_index.json")
|
||||
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
model_name = _get_model(standard_pipeline_class.__name__)
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
else:
|
||||
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
|
||||
pipeline_class = cls
|
||||
pretrained_model_name_or_path = None
|
||||
# there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
|
||||
pipeline_class = cls
|
||||
pretrained_model_name_or_path = None
|
||||
|
||||
pipeline = pipeline_class(
|
||||
blocks=blocks,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
components_manager=components_manager,
|
||||
collection=collection,
|
||||
modular_config_dict=modular_config_dict,
|
||||
config_dict=config_dict,
|
||||
**kwargs,
|
||||
)
|
||||
return pipeline
|
||||
@@ -2134,7 +2152,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
logger.warning(
|
||||
f"\nFailed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n"
|
||||
"If this component is not required for your workflow you can safely ignore this message.\n\n"
|
||||
"Traceback:\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
|
||||
@@ -132,6 +132,7 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents"),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
@@ -196,11 +197,11 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
block_state.latents = randn_tensor(
|
||||
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
|
||||
)
|
||||
block_state.latents = components.pachifier.pack_latents(block_state.latents)
|
||||
if block_state.latents is None:
|
||||
block_state.latents = randn_tensor(
|
||||
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
|
||||
)
|
||||
block_state.latents = components.pachifier.pack_latents(block_state.latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
@@ -549,8 +550,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
block_state.width // components.vae_scale_factor // 2,
|
||||
)
|
||||
]
|
||||
* block_state.batch_size
|
||||
]
|
||||
] * block_state.batch_size
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
|
||||
@@ -74,8 +74,9 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents, block_state.height, block_state.width
|
||||
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
|
||||
)
|
||||
block_state.latents = block_state.latents.to(components.vae.dtype)
|
||||
|
||||
|
||||
@@ -503,6 +503,8 @@ class QwenImageTextEncoderStep(ModularPipelineBlocks):
|
||||
block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
|
||||
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]
|
||||
|
||||
block_state.negative_prompt_embeds = None
|
||||
block_state.negative_prompt_embeds_mask = None
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = block_state.negative_prompt or ""
|
||||
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
|
||||
@@ -627,6 +629,8 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
|
||||
device=device,
|
||||
)
|
||||
|
||||
block_state.negative_prompt_embeds = None
|
||||
block_state.negative_prompt_embeds_mask = None
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = block_state.negative_prompt or " "
|
||||
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
|
||||
@@ -679,6 +683,8 @@ class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
|
||||
device=device,
|
||||
)
|
||||
|
||||
block_state.negative_prompt_embeds = None
|
||||
block_state.negative_prompt_embeds_mask = None
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = block_state.negative_prompt or " "
|
||||
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
|
||||
|
||||
@@ -523,7 +523,7 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
QwenImageOptionalControlNetBeforeDenoiseStep,
|
||||
QwenImageAutoDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise", "decode"]
|
||||
block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -534,7 +534,6 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
+ " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
|
||||
+ " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n"
|
||||
+ " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
+ " - `QwenImageAutoDecodeStep` (decode) decodes the latents into images.\n\n"
|
||||
+ "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
|
||||
|
||||
@@ -26,10 +26,7 @@ class QwenImagePachifier(ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
):
|
||||
def __init__(self, patch_size: int = 2):
|
||||
super().__init__()
|
||||
|
||||
def pack_latents(self, latents):
|
||||
|
||||
@@ -21,7 +21,6 @@ import torch
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL
|
||||
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
ModularPipelineBlocks,
|
||||
@@ -77,21 +76,7 @@ class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components
|
||||
def upcast_vae(components):
|
||||
dtype = components.vae.dtype
|
||||
components.vae.to(dtype=torch.float32)
|
||||
use_torch_2_0_or_xformers = isinstance(
|
||||
components.vae.decoder.mid_block.attentions[0].processor,
|
||||
(
|
||||
AttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
),
|
||||
)
|
||||
# if xformers or torch_2_0 is used attention block does not need
|
||||
# to be in float32 which can save lots of memory
|
||||
if use_torch_2_0_or_xformers:
|
||||
components.vae.post_quant_conv.to(dtype)
|
||||
components.vae.decoder.conv_in.to(dtype)
|
||||
components.vae.decoder.mid_block.to(dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user