Compare commits
101 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7e65e59ff0 | |||
| e6d4612309 | |||
| a88a7b4f03 | |||
| c8656ed73c | |||
| 94c9613f99 | |||
| b91e8c0d0b | |||
| ac7864624b | |||
| 5ffb73d4ae | |||
| 4088e8a851 | |||
| d33d9f6715 | |||
| dde8754ba2 | |||
| fbcd3ba6b2 | |||
| d176f61fcf | |||
| 354d35adb0 | |||
| 544ba677dd | |||
| 6f1042e36c | |||
| d5da453de5 | |||
| 15370f8412 | |||
| a96b145304 | |||
| 6d8973ffe2 | |||
| ab71f3c864 | |||
| b7df4a5387 | |||
| 67dc65e2e3 | |||
| 3579fdabf9 | |||
| 1afc21855e | |||
| 0c35b580fe | |||
| 01a56927f1 | |||
| a9e4883b6a | |||
| 63dd601758 | |||
| eeae0338e7 | |||
| 3c1ca869d7 | |||
| 6fe4a6ff8e | |||
| 40de88af8c | |||
| 6a2309b98d | |||
| cd3bbe2910 | |||
| 7a001c3ee2 | |||
| d8e4805816 | |||
| 44c3101685 | |||
| d6c63bb956 | |||
| 2f44d63046 | |||
| f3db38c1e7 | |||
| f5e5f34823 | |||
| 093cd3f040 | |||
| aecf0c53bf | |||
| 0c7589293b | |||
| ff263947ad | |||
| 66e6a0215f | |||
| 5a47442f92 | |||
| 8f6328c4a4 | |||
| 8d45f219d0 | |||
| 0fd58c7706 | |||
| 35d703310c | |||
| b455dc94a2 | |||
| 04f9d2bf3d | |||
| bc8fd864eb | |||
| a9cb08af39 | |||
| 9f669e7b5d | |||
| 8ac17cd2cb | |||
| e4393fa613 | |||
| b3e9dfced7 | |||
| 58f3771545 | |||
| 6198f8a12b | |||
| dcfb18a2d3 | |||
| ac5a1e28fc | |||
| 325a95051b | |||
| 1ec28a2c77 | |||
| de6173c683 | |||
| 8f80dda193 | |||
| cdbf0ad883 | |||
| 5e8415a311 | |||
| 051c8a1c0f | |||
| d54622c267 | |||
| df8dd77817 | |||
| 9f3c0fdcd8 | |||
| 84e16575e4 | |||
| 55d49d4379 | |||
| 40528e9ae7 | |||
| dc622a95d0 | |||
| ecfbc8f952 | |||
| df0e2a4f2c | |||
| 303efd2b8d | |||
| 5afbcce176 | |||
| 6d1a648602 | |||
| 250f5cb53d | |||
| dc6bd1511a | |||
| 500b9cf184 | |||
| d34b18c783 | |||
| 7536f647e4 | |||
| a138d71ec1 | |||
| bc4039886d | |||
| 9c3b58dcf1 | |||
| 74b5fed434 | |||
| 85eb505672 | |||
| ccdd96ca52 | |||
| 4c723d8ec3 | |||
| bec2d8eaea | |||
| a0a51eb098 | |||
| a5a0ccf86a | |||
| dd07b19e27 | |||
| 57636ad4f4 | |||
| cefc2cf82d |
@@ -7,7 +7,7 @@ on:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
|
||||
@@ -42,18 +42,39 @@ jobs:
|
||||
CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
|
||||
run: |
|
||||
echo "$CHANGED_FILES"
|
||||
for FILE in $CHANGED_FILES; do
|
||||
ALLOWED_IMAGES=(
|
||||
diffusers-pytorch-cpu
|
||||
diffusers-pytorch-cuda
|
||||
diffusers-pytorch-xformers-cuda
|
||||
diffusers-pytorch-minimum-cuda
|
||||
diffusers-doc-builder
|
||||
)
|
||||
|
||||
declare -A IMAGES_TO_BUILD=()
|
||||
|
||||
for FILE in $CHANGED_FILES; do
|
||||
# skip anything that isn't still on disk
|
||||
if [[ ! -f "$FILE" ]]; then
|
||||
if [[ ! -e "$FILE" ]]; then
|
||||
echo "Skipping removed file $FILE"
|
||||
continue
|
||||
fi
|
||||
if [[ "$FILE" == docker/*Dockerfile ]]; then
|
||||
DOCKER_PATH="${FILE%/Dockerfile}"
|
||||
DOCKER_TAG=$(basename "$DOCKER_PATH")
|
||||
echo "Building Docker image for $DOCKER_TAG"
|
||||
docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
|
||||
fi
|
||||
|
||||
for IMAGE in "${ALLOWED_IMAGES[@]}"; do
|
||||
if [[ "$FILE" == docker/${IMAGE}/* ]]; then
|
||||
IMAGES_TO_BUILD["$IMAGE"]=1
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then
|
||||
echo "No relevant Docker changes detected."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
for IMAGE in "${!IMAGES_TO_BUILD[@]}"; do
|
||||
DOCKER_PATH="docker/${IMAGE}"
|
||||
echo "Building Docker image for $IMAGE"
|
||||
docker build -t "$IMAGE" "$DOCKER_PATH"
|
||||
done
|
||||
if: steps.file_changes.outputs.all != ''
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
PYTEST_TIMEOUT: 600
|
||||
@@ -73,6 +73,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -84,7 +86,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
@@ -126,6 +128,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
@@ -138,7 +142,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_${{ matrix.module }}_cuda \
|
||||
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
|
||||
tests/${{ matrix.module }}
|
||||
@@ -151,7 +155,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v --make-reports=examples_torch_cuda \
|
||||
--make-reports=examples_torch_cuda \
|
||||
--report-log=examples_torch_cuda.log \
|
||||
examples/
|
||||
|
||||
@@ -190,6 +194,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -198,7 +204,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
@@ -232,6 +238,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -281,6 +289,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -293,7 +303,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_minimum_version_cuda \
|
||||
tests/models/test_modeling_common.py \
|
||||
tests/pipelines/test_pipelines_common.py \
|
||||
@@ -358,6 +368,8 @@ jobs:
|
||||
uv pip install ${{ join(matrix.config.additional_deps, ' ') }}
|
||||
fi
|
||||
uv pip install pytest-reportlog
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -405,6 +417,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install -U bitsandbytes optimum_quanto
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -531,7 +545,7 @@ jobs:
|
||||
# HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# run: |
|
||||
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
|
||||
# --report-log=tests_torch_mps.log \
|
||||
# tests/
|
||||
# - name: Failure short reports
|
||||
@@ -587,7 +601,7 @@ jobs:
|
||||
# HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
# HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# run: |
|
||||
# ${CONDA_RUN} pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
# ${CONDA_RUN} pytest -n 1 --make-reports=tests_torch_mps \
|
||||
# --report-log=tests_torch_mps.log \
|
||||
# tests/
|
||||
# - name: Failure short reports
|
||||
|
||||
@@ -26,7 +26,7 @@ concurrency:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
@@ -109,7 +109,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
@@ -120,7 +121,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ concurrency:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
@@ -115,7 +115,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
@@ -126,7 +127,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/pipelines
|
||||
|
||||
@@ -134,7 +135,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch_models' }}
|
||||
run: |
|
||||
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and not Dependency" \
|
||||
-k "not Flax and not Onnx and not Dependency" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
@@ -246,7 +247,8 @@ jobs:
|
||||
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
uv pip install -U tokenizers
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -255,11 +257,11 @@ jobs:
|
||||
- name: Run fast PyTorch LoRA tests with PEFT
|
||||
run: |
|
||||
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
\
|
||||
--make-reports=tests_peft_main \
|
||||
tests/lora/
|
||||
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
\
|
||||
--make-reports=tests_models_lora_peft_main \
|
||||
tests/models/ -k "lora"
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Fast GPU Tests on PR
|
||||
name: Fast GPU Tests on PR
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
@@ -24,7 +24,7 @@ env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run
|
||||
|
||||
@@ -71,7 +71,7 @@ jobs:
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
@@ -131,7 +131,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -149,18 +150,18 @@ jobs:
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
|
||||
if [ "${{ matrix.module }}" = "ip_adapters" ]; then
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
else
|
||||
else
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and $pattern" \
|
||||
-k "not Flax and not Onnx and $pattern" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
fi
|
||||
fi
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
@@ -201,7 +202,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -222,11 +224,11 @@ jobs:
|
||||
run: |
|
||||
pattern=$(cat ${{ steps.extract_tests.outputs.pattern_file }})
|
||||
if [ -z "$pattern" ]; then
|
||||
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
else
|
||||
pytest -n 1 -sv --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "not Flax and not Onnx and $pattern" tests/${{ matrix.module }} \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }}
|
||||
fi
|
||||
|
||||
- name: Failure short reports
|
||||
@@ -262,7 +264,8 @@ jobs:
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install -e ".[quality,training]"
|
||||
|
||||
- name: Environment
|
||||
@@ -274,7 +277,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
uv pip install ".[training]"
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -14,7 +14,7 @@ env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
PIPELINE_USAGE_CUTOFF: 50000
|
||||
|
||||
@@ -76,6 +76,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -86,7 +88,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
- name: Failure short reports
|
||||
@@ -127,6 +129,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -139,7 +143,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_cuda_${{ matrix.module }} \
|
||||
tests/${{ matrix.module }}
|
||||
|
||||
@@ -178,6 +182,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -186,7 +192,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
@@ -227,7 +233,7 @@ jobs:
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
|
||||
@@ -270,7 +276,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
uv pip install ".[training]"
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -18,7 +18,7 @@ env:
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: no
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
run: |
|
||||
pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ env:
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
HF_HUB_ENABLE_HF_TRANSFER: 1
|
||||
HF_XET_HIGH_PERFORMANCE: 1
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: no
|
||||
|
||||
@@ -57,7 +57,7 @@ jobs:
|
||||
HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/
|
||||
${CONDA_RUN} python -m pytest -n 0 --make-reports=tests_torch_mps tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -84,7 +84,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
- name: Failure short reports
|
||||
@@ -137,7 +137,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_${{ matrix.module }}_cuda \
|
||||
tests/${{ matrix.module }}
|
||||
|
||||
@@ -187,7 +187,7 @@ jobs:
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_minimum_cuda \
|
||||
tests/models/test_modeling_common.py \
|
||||
tests/pipelines/test_pipelines_common.py \
|
||||
@@ -240,7 +240,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
@@ -281,7 +281,7 @@ jobs:
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
|
||||
@@ -326,7 +326,7 @@ jobs:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
uv pip install ".[training]"
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
pytest -n 1 --max-worker-restart=0 --dist=loadfile --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
|
||||
@@ -33,7 +33,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
hf_transfer \
|
||||
hf_xet \
|
||||
setuptools==69.5.1 \
|
||||
bitsandbytes \
|
||||
torchao \
|
||||
|
||||
@@ -44,6 +44,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
hf_transfer
|
||||
hf_xet
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -38,13 +38,12 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
hf_transfer \
|
||||
hf_xet \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy==1.26.4 \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
hf_transfer
|
||||
transformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
@@ -31,7 +31,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
|
||||
RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
hf_transfer
|
||||
hf_xet
|
||||
|
||||
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean
|
||||
|
||||
|
||||
@@ -44,6 +44,6 @@ RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
pytorch-lightning \
|
||||
hf_transfer
|
||||
hf_xet
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -47,6 +47,6 @@ RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
pytorch-lightning \
|
||||
hf_transfer
|
||||
hf_xet
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -44,7 +44,7 @@ RUN uv pip install --no-cache-dir \
|
||||
accelerate \
|
||||
numpy==1.26.4 \
|
||||
pytorch-lightning \
|
||||
hf_transfer \
|
||||
hf_xet \
|
||||
xformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -22,6 +22,8 @@
|
||||
title: Reproducibility
|
||||
- local: using-diffusers/schedulers
|
||||
title: Schedulers
|
||||
- local: using-diffusers/automodel
|
||||
title: AutoModel
|
||||
- local: using-diffusers/other-formats
|
||||
title: Model formats
|
||||
- local: using-diffusers/push_to_hub
|
||||
@@ -119,6 +121,8 @@
|
||||
title: ComponentsManager
|
||||
- local: modular_diffusers/guiders
|
||||
title: Guiders
|
||||
- local: modular_diffusers/custom_blocks
|
||||
title: Building Custom Blocks
|
||||
title: Modular Diffusers
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@@ -323,10 +327,14 @@
|
||||
title: AllegroTransformer3DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/transformer_bria_fibo
|
||||
title: BriaFiboTransformer2DModel
|
||||
- local: api/models/bria_transformer
|
||||
title: BriaTransformer2DModel
|
||||
- local: api/models/chroma_transformer
|
||||
title: ChromaTransformer2DModel
|
||||
- local: api/models/chronoedit_transformer_3d
|
||||
title: ChronoEditTransformer3DModel
|
||||
- local: api/models/cogvideox_transformer3d
|
||||
title: CogVideoXTransformer3DModel
|
||||
- local: api/models/cogview3plus_transformer2d
|
||||
@@ -341,12 +349,16 @@
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/easyanimate_transformer3d
|
||||
title: EasyAnimateTransformer3DModel
|
||||
- local: api/models/flux2_transformer
|
||||
title: Flux2Transformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hidream_image_transformer
|
||||
title: HiDreamImageTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/hunyuanimage_transformer_2d
|
||||
title: HunyuanImageTransformer2DModel
|
||||
- local: api/models/hunyuan_video_transformer_3d
|
||||
title: HunyuanVideoTransformer3DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
@@ -369,6 +381,8 @@
|
||||
title: QwenImageTransformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/sana_video_transformer3d
|
||||
title: SanaVideoTransformer3DModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/skyreels_v2_transformer_3d
|
||||
@@ -379,6 +393,8 @@
|
||||
title: Transformer2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/wan_animate_transformer_3d
|
||||
title: WanAnimateTransformer3DModel
|
||||
- local: api/models/wan_transformer_3d
|
||||
title: WanTransformer3DModel
|
||||
title: Transformers
|
||||
@@ -411,6 +427,10 @@
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/autoencoderkl_cosmos
|
||||
title: AutoencoderKLCosmos
|
||||
- local: api/models/autoencoder_kl_hunyuanimage
|
||||
title: AutoencoderKLHunyuanImage
|
||||
- local: api/models/autoencoder_kl_hunyuanimage_refiner
|
||||
title: AutoencoderKLHunyuanImageRefiner
|
||||
- local: api/models/autoencoder_kl_hunyuan_video
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
@@ -436,6 +456,8 @@
|
||||
- sections:
|
||||
- local: api/pipelines/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- sections:
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
@@ -448,8 +470,6 @@
|
||||
- local: api/pipelines/stable_audio
|
||||
title: Stable Audio
|
||||
title: Audio
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- sections:
|
||||
- local: api/pipelines/amused
|
||||
title: aMUSEd
|
||||
@@ -463,6 +483,8 @@
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/bria_3_2
|
||||
title: Bria 3.2
|
||||
- local: api/pipelines/bria_fibo
|
||||
title: Bria Fibo
|
||||
- local: api/pipelines/chroma
|
||||
title: Chroma
|
||||
- local: api/pipelines/cogview3
|
||||
@@ -505,12 +527,16 @@
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/flux2
|
||||
title: Flux2
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuanimage21
|
||||
title: HunyuanImage2.1
|
||||
- local: api/pipelines/pix2pix
|
||||
title: InstructPix2Pix
|
||||
- local: api/pipelines/kandinsky
|
||||
@@ -545,12 +571,16 @@
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pixart_sigma
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/prx
|
||||
title: PRX
|
||||
- local: api/pipelines/qwenimage
|
||||
title: QwenImage
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/sana_video
|
||||
title: Sana Video
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
@@ -612,6 +642,8 @@
|
||||
- sections:
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
- local: api/pipelines/chronoedit
|
||||
title: ChronoEdit
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
@@ -622,6 +654,8 @@
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/kandinsky5_video
|
||||
title: Kandinsky 5.0 Video
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx_video
|
||||
|
||||
@@ -30,7 +30,8 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen)
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
|
||||
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
> [!TIP]
|
||||
@@ -56,6 +57,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
|
||||
|
||||
## Flux2LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
|
||||
|
||||
## CogVideoXLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
|
||||
|
||||
@@ -12,15 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# AutoModel
|
||||
|
||||
The `AutoModel` is designed to make it easy to load a checkpoint without needing to know the specific model class. `AutoModel` automatically retrieves the correct model class from the checkpoint `config.json` file.
|
||||
|
||||
```python
|
||||
from diffusers import AutoModel, AutoPipelineForText2Image
|
||||
|
||||
unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
|
||||
pipe = AutoPipelineForText2Image.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet)
|
||||
```
|
||||
|
||||
[`AutoModel`] automatically retrieves the correct model class from the checkpoint `config.json` file.
|
||||
|
||||
## AutoModel
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLHunyuanImage
|
||||
|
||||
The 2D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1].
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanImage
|
||||
|
||||
vae = AutoencoderKLHunyuanImage.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanImage
|
||||
|
||||
[[autodoc]] AutoencoderKLHunyuanImage
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLHunyuanImageRefiner
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) for its refiner pipeline.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanImageRefiner
|
||||
|
||||
vae = AutoencoderKLHunyuanImageRefiner.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", subfolder="vae", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanImageRefiner
|
||||
|
||||
[[autodoc]] AutoencoderKLHunyuanImageRefiner
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# ChromaTransformer2DModel
|
||||
|
||||
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
|
||||
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
|
||||
|
||||
## ChromaTransformer2DModel
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# ChronoEditTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data from [ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
|
||||
|
||||
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import ChronoEditTransformer3DModel
|
||||
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained("nvidia/ChronoEdit-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## ChronoEditTransformer3DModel
|
||||
|
||||
[[autodoc]] ChronoEditTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Flux2Transformer2DModel
|
||||
|
||||
A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-labs/FLUX.2-dev).
|
||||
|
||||
## Flux2Transformer2DModel
|
||||
|
||||
[[autodoc]] Flux2Transformer2DModel
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# HunyuanImageTransformer2DModel
|
||||
|
||||
A Diffusion Transformer model for [HunyuanImage2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanImageTransformer2DModel
|
||||
|
||||
transformer = HunyuanImageTransformer2DModel.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HunyuanImageTransformer2DModel
|
||||
|
||||
[[autodoc]] HunyuanImageTransformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,36 @@
|
||||
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# SanaVideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data (video) from [SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation.*
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import SanaVideoTransformer3DModel
|
||||
import torch
|
||||
|
||||
transformer = SanaVideoTransformer3DModel.from_pretrained("Efficient-Large-Model/SANA-Video_2B_480p_diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## SanaVideoTransformer3DModel
|
||||
|
||||
[[autodoc]] SanaVideoTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# BriaFiboTransformer2DModel
|
||||
|
||||
A modified flux Transformer model from [Bria](https://huggingface.co/briaai/FIBO)
|
||||
|
||||
## BriaFiboTransformer2DModel
|
||||
|
||||
[[autodoc]] BriaFiboTransformer2DModel
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# WanAnimateTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [Wan Animate](https://github.com/Wan-Video/Wan2.2) by the Alibaba Wan Team.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import WanAnimateTransformer3DModel
|
||||
|
||||
transformer = WanAnimateTransformer3DModel.from_pretrained("Wan-AI/Wan2.2-Animate-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## WanAnimateTransformer3DModel
|
||||
|
||||
[[autodoc]] WanAnimateTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,45 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Bria Fibo
|
||||
|
||||
Text-to-image models have mastered imagination - but not control. FIBO changes that.
|
||||
|
||||
FIBO is trained on structured JSON captions up to 1,000+ words and designed to understand and control different visual parameters such as lighting, composition, color, and camera settings, enabling precise and reproducible outputs.
|
||||
|
||||
With only 8 billion parameters, FIBO provides a new level of image quality, prompt adherence and proffesional control.
|
||||
|
||||
FIBO is trained exclusively on a structured prompt and will not work with freeform text prompts.
|
||||
you can use the [FIBO-VLM-prompt-to-JSON](https://huggingface.co/briaai/FIBO-VLM-prompt-to-JSON) model or the [FIBO-gemini-prompt-to-JSON](https://huggingface.co/briaai/FIBO-gemini-prompt-to-JSON) to convert your freeform text prompt to a structured JSON prompt.
|
||||
|
||||
its not recommended to use freeform text prompts directly with FIBO, as it will not produce the best results.
|
||||
|
||||
you can learn more about FIBO in [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO).
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/FIBO), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
|
||||
|
||||
Use the command below to log in:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
```
|
||||
|
||||
|
||||
## BriaPipeline
|
||||
|
||||
[[autodoc]] BriaPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Chroma is a text to image generation model based on Flux.
|
||||
|
||||
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
|
||||
Original model checkpoints for Chroma can be found here:
|
||||
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
|
||||
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
|
||||
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
|
||||
|
||||
> [!TIP]
|
||||
> Chroma can use all the same optimizations as Flux.
|
||||
|
||||
## Inference
|
||||
|
||||
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ChromaPipeline
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
|
||||
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = [
|
||||
@@ -63,10 +64,10 @@ Then run the following example
|
||||
import torch
|
||||
from diffusers import ChromaTransformer2DModel, ChromaPipeline
|
||||
|
||||
model_id = "lodestones/Chroma"
|
||||
model_id = "lodestones/Chroma1-HD"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
@@ -0,0 +1,156 @@
|
||||
<!-- Copyright 2025 The ChronoEdit Team and HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# ChronoEdit
|
||||
|
||||
[ChronoEdit: Towards Temporal Reasoning for Image Editing and World Simulation](https://huggingface.co/papers/2510.04290) from NVIDIA and University of Toronto, by Jay Zhangjie Wu, Xuanchi Ren, Tianchang Shen, Tianshi Cao, Kai He, Yifan Lu, Ruiyuan Gao, Enze Xie, Shiyi Lan, Jose M. Alvarez, Jun Gao, Sanja Fidler, Zian Wang, Huan Ling.
|
||||
|
||||
> **TL;DR:** ChronoEdit reframes image editing as a video generation task, using input and edited images as start/end frames to leverage pretrained video models with temporal consistency. A temporal reasoning stage introduces reasoning tokens to ensure physically plausible edits and visualize the editing trajectory.
|
||||
|
||||
*Recent advances in large generative models have greatly enhanced both image editing and in-context image generation, yet a critical gap remains in ensuring physical consistency, where edited objects must remain coherent. This capability is especially vital for world simulation related tasks. In this paper, we present ChronoEdit, a framework that reframes image editing as a video generation problem. First, ChronoEdit treats the input and edited images as the first and last frames of a video, allowing it to leverage large pretrained video generative models that capture not only object appearance but also the implicit physics of motion and interaction through learned temporal consistency. Second, ChronoEdit introduces a temporal reasoning stage that explicitly performs editing at inference time. Under this setting, target frame is jointly denoised with reasoning tokens to imagine a plausible editing trajectory that constrains the solution space to physically viable transformations. The reasoning tokens are then dropped after a few steps to avoid the high computational cost of rendering a full video. To validate ChronoEdit, we introduce PBench-Edit, a new benchmark of image-prompt pairs for contexts that require physical consistency, and demonstrate that ChronoEdit surpasses state-of-the-art baselines in both visual fidelity and physical plausibility. Project page for code and models: [this https URL](https://research.nvidia.com/labs/toronto-ai/chronoedit).*
|
||||
|
||||
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
|
||||
|
||||
|
||||
### Image Editing
|
||||
|
||||
```py
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
from PIL import Image
|
||||
|
||||
model_id = "nvidia/ChronoEdit-14B-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
|
||||
)
|
||||
max_area = 720 * 1280
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
print("width", width, "height", height)
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
|
||||
"The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=5,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
enable_temporal_reasoning=False,
|
||||
num_temporal_reasoning_steps=0,
|
||||
).frames[0]
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
Optionally, enable **temporal reasoning** for improved physical consistency:
|
||||
```py
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=29,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
enable_temporal_reasoning=True,
|
||||
num_temporal_reasoning_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
### Inference with 8-Step Distillation Lora
|
||||
|
||||
```py
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
from PIL import Image
|
||||
|
||||
model_id = "nvidia/ChronoEdit-14B-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
|
||||
pipe.load_lora_weights(lora_path)
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://huggingface.co/spaces/nvidia/ChronoEdit/resolve/main/examples/3.png"
|
||||
)
|
||||
max_area = 720 * 1280
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
print("width", width, "height", height)
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
|
||||
"The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=5,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=1.0,
|
||||
enable_temporal_reasoning=False,
|
||||
num_temporal_reasoning_steps=0,
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
## ChronoEditPipeline
|
||||
|
||||
[[autodoc]] ChronoEditPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ChronoEditPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.chronoedit.pipeline_output.ChronoEditPipelineOutput
|
||||
@@ -0,0 +1,33 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Flux2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
Flux.2 is the recent series of image generation models from Black Forest Labs, preceded by the [Flux.1](./flux.md) series. It is an entirely new model with a new architecture and pre-training done from scratch!
|
||||
|
||||
Original model checkpoints for Flux can be found [here](https://huggingface.co/black-forest-labs). Original inference code can be found [here](https://github.com/black-forest-labs/flux2).
|
||||
|
||||
> [!TIP]
|
||||
> Flux2 can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more.
|
||||
>
|
||||
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
## Flux2Pipeline
|
||||
|
||||
[[autodoc]] Flux2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -0,0 +1,152 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# HunyuanImage2.1
|
||||
|
||||
|
||||
HunyuanImage-2.1 is a 17B text-to-image model that is capable of generating 2K (2048 x 2048) resolution images
|
||||
|
||||
HunyuanImage-2.1 comes in the following variants:
|
||||
|
||||
| model type | model id |
|
||||
|:----------:|:--------:|
|
||||
| HunyuanImage-2.1 | [hunyuanvideo-community/HunyuanImage-2.1-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Diffusers) |
|
||||
| HunyuanImage-2.1-Distilled | [hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers) |
|
||||
| HunyuanImage-2.1-Refiner | [hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers](https://huggingface.co/hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers) |
|
||||
|
||||
> [!TIP]
|
||||
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
## HunyuanImage-2.1
|
||||
|
||||
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HunyuanImagePipeline
|
||||
|
||||
pipe = HunyuanImagePipeline.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
```
|
||||
|
||||
You can inspect the `guider` object:
|
||||
|
||||
```py
|
||||
>>> pipe.guider
|
||||
AdaptiveProjectedMixGuidance {
|
||||
"_class_name": "AdaptiveProjectedMixGuidance",
|
||||
"_diffusers_version": "0.36.0.dev0",
|
||||
"adaptive_projected_guidance_momentum": -0.5,
|
||||
"adaptive_projected_guidance_rescale": 10.0,
|
||||
"adaptive_projected_guidance_scale": 10.0,
|
||||
"adaptive_projected_guidance_start_step": 5,
|
||||
"enabled": true,
|
||||
"eta": 0.0,
|
||||
"guidance_rescale": 0.0,
|
||||
"guidance_scale": 3.5,
|
||||
"start": 0.0,
|
||||
"stop": 1.0,
|
||||
"use_original_formulation": false
|
||||
}
|
||||
|
||||
State:
|
||||
step: None
|
||||
num_inference_steps: None
|
||||
timestep: None
|
||||
count_prepared: 0
|
||||
enabled: True
|
||||
num_conditions: 2
|
||||
momentum_buffer: None
|
||||
is_apg_enabled: False
|
||||
is_cfg_enabled: True
|
||||
```
|
||||
|
||||
To update the guider with a different configuration, use the `new()` method. For example, to generate an image with `guidance_scale=5.0` while keeping all other default guidance parameters:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import HunyuanImagePipeline
|
||||
|
||||
pipe = HunyuanImagePipeline.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
# Update the guider configuration
|
||||
pipe.guider = pipe.guider.new(guidance_scale=5.0)
|
||||
|
||||
prompt = (
|
||||
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
|
||||
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
|
||||
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
|
||||
)
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
num_inference_steps=50,
|
||||
height=2048,
|
||||
width=2048,
|
||||
).images[0]
|
||||
image.save("image.png")
|
||||
```
|
||||
|
||||
|
||||
## HunyuanImage-2.1-Distilled
|
||||
|
||||
use `distilled_guidance_scale` with the guidance-distilled checkpoint,
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import HunyuanImagePipeline
|
||||
pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanImage-2.1-Distilled-Diffusers", torch_dtype=torch.bfloat16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = (
|
||||
"A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, "
|
||||
"wearing a red knitted scarf and a red beret with the word 'Tencent' on it, holding a paintbrush with a "
|
||||
"focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
|
||||
)
|
||||
|
||||
out = pipe(
|
||||
prompt,
|
||||
num_inference_steps=8,
|
||||
distilled_guidance_scale=3.25,
|
||||
height=2048,
|
||||
width=2048,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
|
||||
```
|
||||
|
||||
|
||||
## HunyuanImagePipeline
|
||||
|
||||
[[autodoc]] HunyuanImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HunyuanImageRefinerPipeline
|
||||
|
||||
[[autodoc]] HunyuanImageRefinerPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## HunyuanImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.hunyuan_image.pipeline_output.HunyuanImagePipelineOutput
|
||||
@@ -0,0 +1,149 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Kandinsky 5.0 Video
|
||||
|
||||
Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov
|
||||
|
||||
|
||||
Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem.
|
||||
|
||||
The model introduces several key innovations:
|
||||
- **Latent diffusion pipeline** with **Flow Matching** for improved training stability
|
||||
- **Diffusion Transformer (DiT)** as the main generative backbone with cross-attention to text embeddings
|
||||
- Dual text encoding using **Qwen2.5-VL** and **CLIP** for comprehensive text understanding
|
||||
- **HunyuanVideo 3D VAE** for efficient video encoding and decoding
|
||||
- **Sparse attention mechanisms** (NABLA) for efficient long-sequence processing
|
||||
|
||||
The original codebase can be found at [ai-forever/Kandinsky-5](https://github.com/ai-forever/Kandinsky-5).
|
||||
|
||||
> [!TIP]
|
||||
> Check out the [AI Forever](https://huggingface.co/ai-forever) organization on the Hub for the official model checkpoints for text-to-video generation, including pretrained, SFT, no-CFG, and distilled variants.
|
||||
|
||||
## Available Models
|
||||
|
||||
Kandinsky 5.0 T2V Lite comes in several variants optimized for different use cases:
|
||||
|
||||
| model_id | Description | Use Cases |
|
||||
|------------|-------------|-----------|
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers** | 5 second Supervised Fine-Tuned model | Highest generation quality |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers** | 10 second Supervised Fine-Tuned model | Highest generation quality |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers** | 5 second Classifier-Free Guidance distilled | 2× faster inference |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-10s-Diffusers** | 10 second Classifier-Free Guidance distilled | 2× faster inference |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers** | 5 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-10s-Diffusers** | 10 second Diffusion distilled to 16 steps | 6× faster inference, minimal quality loss |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers** | 5 second Base pretrained model | Research and fine-tuning |
|
||||
| **ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-10s-Diffusers** | 10 second Base pretrained model | Research and fine-tuning |
|
||||
|
||||
All models are available in 5-second and 10-second video generation versions.
|
||||
|
||||
## Kandinsky5T2VPipeline
|
||||
|
||||
[[autodoc]] Kandinsky5T2VPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Text-to-Video Generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Kandinsky5T2VPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
# Load the pipeline
|
||||
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
|
||||
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
# Generate video
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen."
|
||||
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=512,
|
||||
width=768,
|
||||
num_frames=121, # ~5 seconds at 24fps
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
### 10 second Models
|
||||
**⚠️ Warning!** all 10 second models should be used with Flex attention and max-autotune-no-cudagraphs compilation:
|
||||
|
||||
```python
|
||||
pipe = Kandinsky5T2VPipeline.from_pretrained(
|
||||
"ai-forever/Kandinsky-5.0-T2V-Lite-sft-10s-Diffusers",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
pipe.transformer.set_attention_backend(
|
||||
"flex"
|
||||
) # <--- Sett attention bakend to Flex
|
||||
pipe.transformer.compile(
|
||||
mode="max-autotune-no-cudagraphs",
|
||||
dynamic=True
|
||||
) # <--- Compile with max-autotune-no-cudagraphs
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen."
|
||||
negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=512,
|
||||
width=768,
|
||||
num_frames=241,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
### Diffusion Distilled model
|
||||
**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```):
|
||||
|
||||
```python
|
||||
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
|
||||
pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
output = pipe(
|
||||
prompt="A beautiful sunset over mountains",
|
||||
num_inference_steps=16, # <--- Model is distilled in 16 steps
|
||||
guidance_scale=1.0, # <--- no CFG
|
||||
).frames[0]
|
||||
|
||||
export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
|
||||
|
||||
## Citation
|
||||
```bibtex
|
||||
@misc{kandinsky2025,
|
||||
author = {Alexey Letunovskiy and Maria Kovaleva and Ivan Kirillov and Lev Novitskiy and Denis Koposov and
|
||||
Dmitrii Mikhailov and Anna Averchenkova and Andrey Shutkin and Julia Agafonova and Olga Kim and
|
||||
Anastasiia Kargapoltseva and Nikita Kiselev and Vladimir Arkhipkin and Vladimir Korviakov and
|
||||
Nikolai Gerasimenko and Denis Parkhomenko and Anna Dmitrienko and Anastasia Maltseva and
|
||||
Kirill Chernyshev and Ilia Vasiliev and Viacheslav Vasilev and Vladimir Polovnikov and
|
||||
Yury Kolabushin and Alexander Belykh and Mikhail Mamaev and Anastasia Aliaskina and
|
||||
Tatiana Nikulina and Polina Gavrilova and Denis Dimitrov},
|
||||
title = {Kandinsky 5.0: A family of diffusion models for Video & Image generation},
|
||||
howpublished = {\url{https://github.com/ai-forever/Kandinsky-5}},
|
||||
year = 2025
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,131 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# PRX
|
||||
|
||||
|
||||
PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
|
||||
|
||||
## Available models
|
||||
|
||||
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
|
||||
|
||||
|
||||
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|
||||
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
|
||||
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
|
||||
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
|
||||
|
||||
## Loading the pipeline
|
||||
|
||||
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
|
||||
|
||||
```py
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
|
||||
# Load pipeline - VAE and text encoder will be loaded from HuggingFace
|
||||
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A front-facing portrait of a lion the golden savanna at sunset."
|
||||
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
|
||||
image.save("prx_output.png")
|
||||
```
|
||||
|
||||
### Manual Component Loading
|
||||
|
||||
Load components individually to customize the pipeline for instance to use quantized models.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
from diffusers.models import AutoencoderKL, AutoencoderDC
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from transformers import T5GemmaModel, GemmaTokenizerFast
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
# Load transformer
|
||||
transformer = PRXTransformer2DModel.from_pretrained(
|
||||
"checkpoints/prx-512-t2i-sft",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Load scheduler
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
"checkpoints/prx-512-t2i-sft", subfolder="scheduler"
|
||||
)
|
||||
|
||||
# Load T5Gemma text encoder
|
||||
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16)
|
||||
text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
tokenizer.model_max_length = 256
|
||||
|
||||
# Load VAE - choose either Flux VAE or DC-AE
|
||||
# Flux VAE
|
||||
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
|
||||
subfolder="vae",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.bfloat16)
|
||||
|
||||
pipe = PRXPipeline(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae
|
||||
)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
|
||||
## Memory Optimization
|
||||
|
||||
For memory-constrained environments:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
|
||||
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
|
||||
|
||||
# Or use sequential CPU offload for even lower memory
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
```
|
||||
|
||||
## PRXPipeline
|
||||
|
||||
[[autodoc]] PRXPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## PRXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
|
||||
@@ -24,9 +24,6 @@ The abstract from the paper is:
|
||||
|
||||
*This paper presents SANA-Sprint, an efficient diffusion model for ultra-fast text-to-image (T2I) generation. SANA-Sprint is built on a pre-trained foundation model and augmented with hybrid distillation, dramatically reducing inference steps from 20 to 1-4. We introduce three key innovations: (1) We propose a training-free approach that transforms a pre-trained flow-matching model for continuous-time consistency distillation (sCM), eliminating costly training from scratch and achieving high training efficiency. Our hybrid distillation strategy combines sCM with latent adversarial distillation (LADD): sCM ensures alignment with the teacher model, while LADD enhances single-step generation fidelity. (2) SANA-Sprint is a unified step-adaptive model that achieves high-quality generation in 1-4 steps, eliminating step-specific training and improving efficiency. (3) We integrate ControlNet with SANA-Sprint for real-time interactive image generation, enabling instant visual feedback for user interaction. SANA-Sprint establishes a new Pareto frontier in speed-quality tradeoffs, achieving state-of-the-art performance with 7.59 FID and 0.74 GenEval in only 1 step — outperforming FLUX-schnell (7.94 FID / 0.71 GenEval) while being 10× faster (0.1s vs 1.1s on H100). It also achieves 0.1s (T2I) and 0.25s (ControlNet) latency for 1024×1024 images on H100, and 0.31s (T2I) on an RTX 4090, showcasing its exceptional efficiency and potential for AI-powered consumer applications (AIPC). Code and pre-trained models will be open-sourced.*
|
||||
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj), [shuchen Xue](https://github.com/scxue) and [Enze Xie](https://github.com/xieenze). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model/).
|
||||
|
||||
Available models:
|
||||
|
||||
@@ -0,0 +1,189 @@
|
||||
<!-- Copyright 2025 The SANA-Video Authors and HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# Sana-Video
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
<img alt="MPS" src="https://img.shields.io/badge/MPS-000000?style=flat&logo=apple&logoColor=white%22">
|
||||
</div>
|
||||
|
||||
[SANA-Video: Efficient Video Generation with Block Linear Diffusion Transformer](https://huggingface.co/papers/2509.24695) from NVIDIA and MIT HAN Lab, by Junsong Chen, Yuyang Zhao, Jincheng Yu, Ruihang Chu, Junyu Chen, Shuai Yang, Xianbang Wang, Yicheng Pan, Daquan Zhou, Huan Ling, Haozhe Liu, Hongwei Yi, Hao Zhang, Muyang Li, Yukang Chen, Han Cai, Sanja Fidler, Ping Luo, Song Han, Enze Xie.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce SANA-Video, a small diffusion model that can efficiently generate videos up to 720x1280 resolution and minute-length duration. SANA-Video synthesizes high-resolution, high-quality and long videos with strong text-video alignment at a remarkably fast speed, deployable on RTX 5090 GPU. Two core designs ensure our efficient, effective and long video generation: (1) Linear DiT: We leverage linear attention as the core operation, which is more efficient than vanilla attention given the large number of tokens processed in video generation. (2) Constant-Memory KV cache for Block Linear Attention: we design block-wise autoregressive approach for long video generation by employing a constant-memory state, derived from the cumulative properties of linear attention. This KV cache provides the Linear DiT with global context at a fixed memory cost, eliminating the need for a traditional KV cache and enabling efficient, minute-long video generation. In addition, we explore effective data filters and model training strategies, narrowing the training cost to 12 days on 64 H100 GPUs, which is only 1% of the cost of MovieGen. Given its low cost, SANA-Video achieves competitive performance compared to modern state-of-the-art small diffusion models (e.g., Wan 2.1-1.3B and SkyReel-V2-1.3B) while being 16x faster in measured latency. Moreover, SANA-Video can be deployed on RTX 5090 GPUs with NVFP4 precision, accelerating the inference speed of generating a 5-second 720p video from 71s to 29s (2.4x speedup). In summary, SANA-Video enables low-cost, high-quality video generation. [this https URL](https://github.com/NVlabs/SANA).*
|
||||
|
||||
This pipeline was contributed by SANA Team. The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://hf.co/collections/Efficient-Large-Model/sana-video).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/SANA-Video_2B_480p_diffusers`](https://huggingface.co/Efficient-Large-Model/ANA-Video_2B_480p_diffusers) | `torch.bfloat16` |
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-video) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
|
||||
## Generation Pipelines
|
||||
|
||||
<hfoptions id="generation pipelines">`
|
||||
<hfoption id="Text-to-Video">
|
||||
|
||||
The example below demonstrates how to use the text-to-video pipeline to generate a video using a text description.
|
||||
|
||||
```python
|
||||
pipe = SanaVideoPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.text_encoder.to(torch.bfloat16)
|
||||
pipe.vae.to(torch.float32)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
||||
motion_scale = 30
|
||||
motion_prompt = f" motion score: {motion_scale}."
|
||||
prompt = prompt + motion_prompt
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=832,
|
||||
frames=81,
|
||||
guidance_scale=6,
|
||||
num_inference_steps=50,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
).frames[0]
|
||||
|
||||
export_to_video(video, "sana_video.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Image-to-Video">
|
||||
|
||||
The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description and a starting frame.
|
||||
|
||||
```python
|
||||
pipe = SanaImageToVideoPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
|
||||
pipe.vae.to(torch.float32)
|
||||
pipe.text_encoder.to(torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image("https://raw.githubusercontent.com/NVlabs/Sana/refs/heads/main/asset/samples/i2v-1.png")
|
||||
prompt = "A woman stands against a stunning sunset backdrop, her long, wavy brown hair gently blowing in the breeze. She wears a sleeveless, light-colored blouse with a deep V-neckline, which accentuates her graceful posture. The warm hues of the setting sun cast a golden glow across her face and hair, creating a serene and ethereal atmosphere. The background features a blurred landscape with soft, rolling hills and scattered clouds, adding depth to the scene. The camera remains steady, capturing the tranquil moment from a medium close-up angle."
|
||||
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
||||
motion_scale = 30
|
||||
motion_prompt = f" motion score: {motion_scale}."
|
||||
prompt = prompt + motion_prompt
|
||||
|
||||
motion_scale = 30.0
|
||||
|
||||
video = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=832,
|
||||
frames=81,
|
||||
guidance_scale=6,
|
||||
num_inference_steps=50,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
).frames[0]
|
||||
|
||||
export_to_video(video, "sana-i2v.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
## Quantization
|
||||
|
||||
Quantization helps reduce the memory requirements of very large models by storing model weights in a lower precision data type. However, quantization may have varying impact on video quality depending on the video model.
|
||||
|
||||
Refer to the [Quantization](../../quantization/overview) overview to learn more about supported quantization backends and selecting a quantization backend that supports your use case. The example below demonstrates how to load a quantized [`SanaVideoPipeline`] for inference with bitsandbytes.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig, SanaVideoTransformer3DModel, SanaVideoPipeline
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfig, AutoModel
|
||||
|
||||
quant_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
text_encoder_8bit = AutoModel.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
subfolder="text_encoder",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
|
||||
transformer_8bit = SanaVideoTransformer3DModel.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
pipeline = SanaVideoPipeline.from_pretrained(
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p_diffusers",
|
||||
text_encoder=text_encoder_8bit,
|
||||
transformer=transformer_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="balanced",
|
||||
)
|
||||
|
||||
model_score = 30
|
||||
prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
|
||||
negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
|
||||
motion_prompt = f" motion score: {model_score}."
|
||||
prompt = prompt + motion_prompt
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=480,
|
||||
width=832,
|
||||
num_frames=81,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=50
|
||||
).frames[0]
|
||||
export_to_video(output, "sana-video-output.mp4", fps=16)
|
||||
```
|
||||
|
||||
## SanaVideoPipeline
|
||||
|
||||
[[autodoc]] SanaVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaImageToVideoPipeline
|
||||
|
||||
[[autodoc]] SanaImageToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## SanaVideoPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana_video.pipeline_sana_video.SanaVideoPipelineOutput
|
||||
@@ -40,6 +40,7 @@ The following Wan models are supported in Diffusers:
|
||||
- [Wan 2.2 T2V 14B](https://huggingface.co/Wan-AI/Wan2.2-T2V-A14B-Diffusers)
|
||||
- [Wan 2.2 I2V 14B](https://huggingface.co/Wan-AI/Wan2.2-I2V-A14B-Diffusers)
|
||||
- [Wan 2.2 TI2V 5B](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B-Diffusers)
|
||||
- [Wan 2.2 Animate 14B](https://huggingface.co/Wan-AI/Wan2.2-Animate-14B-Diffusers)
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Wan models in the right sidebar for more examples of video generation.
|
||||
@@ -95,15 +96,15 @@ pipeline = WanPipeline.from_pretrained(
|
||||
pipeline.to("cuda")
|
||||
|
||||
prompt = """
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
@@ -150,15 +151,15 @@ pipeline.transformer = torch.compile(
|
||||
)
|
||||
|
||||
prompt = """
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
@@ -249,6 +250,208 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
|
||||
|
||||
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
|
||||
|
||||
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
|
||||
|
||||
*We introduce Wan-Animate, a unified framework for character animation and replacement. Given a character image and a reference video, Wan-Animate can animate the character by precisely replicating the expressions and movements of the character in the video to generate high-fidelity character videos. Alternatively, it can integrate the animated character into the reference video to replace the original character, replicating the scene's lighting and color tone to achieve seamless environmental integration. Wan-Animate is built upon the Wan model. To adapt it for character animation tasks, we employ a modified input paradigm to differentiate between reference conditions and regions for generation. This design unifies multiple tasks into a common symbolic representation. We use spatially-aligned skeleton signals to replicate body motion and implicit facial features extracted from source images to reenact expressions, enabling the generation of character videos with high controllability and expressiveness. Furthermore, to enhance environmental integration during character replacement, we develop an auxiliary Relighting LoRA. This module preserves the character's appearance consistency while applying the appropriate environmental lighting and color tone. Experimental results demonstrate that Wan-Animate achieves state-of-the-art performance. We are committed to open-sourcing the model weights and its source code.*
|
||||
|
||||
The project page: https://humanaigc.github.io/wan-animate
|
||||
|
||||
This model was mostly contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
||||
|
||||
#### Usage
|
||||
|
||||
The Wan-Animate pipeline supports two modes of operation:
|
||||
|
||||
1. **Animation Mode** (default): Animates a character image based on motion and expression from reference videos
|
||||
2. **Replacement Mode**: Replaces a character in a background video with a new character while preserving the scene
|
||||
|
||||
##### Prerequisites
|
||||
|
||||
Before using the pipeline, you need to preprocess your reference video to extract:
|
||||
- **Pose video**: Contains skeletal keypoints representing body motion
|
||||
- **Face video**: Contains facial feature representations for expression control
|
||||
|
||||
For replacement mode, you additionally need:
|
||||
- **Background video**: The original video containing the scene
|
||||
- **Mask video**: A mask indicating where to generate content (white) vs. preserve original (black)
|
||||
|
||||
> [!NOTE]
|
||||
> Raw videos should not be used for inputs such as `pose_video`, which the pipeline expects to be preprocessed to extract the proper information. Preprocessing scripts to prepare these inputs are available in the [original Wan-Animate repository](https://github.com/Wan-Video/Wan2.2?tab=readme-ov-file#1-preprocessing). Integration of these preprocessing steps into Diffusers is planned for a future release.
|
||||
|
||||
The example below demonstrates how to use the Wan-Animate pipeline:
|
||||
|
||||
<hfoptions id="Animate usage">
|
||||
<hfoption id="Animation mode">
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, WanAnimatePipeline
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Load character image and preprocessed videos
|
||||
image = load_image("path/to/character.jpg")
|
||||
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
|
||||
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
|
||||
|
||||
# Resize image to match VAE constraints
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
image, height, width = aspect_ratio_resize(image, pipe)
|
||||
|
||||
prompt = "A person dancing energetically in a studio with dynamic lighting and professional camera work"
|
||||
negative_prompt = "blurry, low quality, distorted, deformed, static, poorly drawn"
|
||||
|
||||
# Generate animated video
|
||||
output = pipe(
|
||||
image=image,
|
||||
pose_video=pose_video,
|
||||
face_video=face_video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
segment_frame_length=77,
|
||||
guidance_scale=1.0,
|
||||
mode="animate", # Animation mode (default)
|
||||
).frames[0]
|
||||
export_to_video(output, "animated_character.mp4", fps=30)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Replacement mode">
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, WanAnimatePipeline
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Load all required inputs for replacement mode
|
||||
image = load_image("path/to/new_character.jpg")
|
||||
pose_video = load_video("path/to/pose_video.mp4") # Preprocessed skeletal keypoints
|
||||
face_video = load_video("path/to/face_video.mp4") # Preprocessed facial features
|
||||
background_video = load_video("path/to/background_video.mp4") # Original scene
|
||||
mask_video = load_video("path/to/mask_video.mp4") # Black: preserve, White: generate
|
||||
|
||||
# Resize image to match video dimensions
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
image, height, width = aspect_ratio_resize(image, pipe)
|
||||
|
||||
prompt = "A person seamlessly integrated into the scene with consistent lighting and environment"
|
||||
negative_prompt = "blurry, low quality, inconsistent lighting, floating, disconnected from scene"
|
||||
|
||||
# Replace character in background video
|
||||
output = pipe(
|
||||
image=image,
|
||||
pose_video=pose_video,
|
||||
face_video=face_video,
|
||||
background_video=background_video,
|
||||
mask_video=mask_video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
segment_frame_lengths=77,
|
||||
guidance_scale=1.0,
|
||||
mode="replace", # Replacement mode
|
||||
).frames[0]
|
||||
export_to_video(output, "character_replaced.mp4", fps=30)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Advanced options">
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import AutoencoderKLWan, WanAnimatePipeline
|
||||
from diffusers.utils import export_to_video, load_image, load_video
|
||||
|
||||
model_id = "Wan-AI/Wan2.2-Animate-14B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
pipe = WanAnimatePipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image("path/to/character.jpg")
|
||||
pose_video = load_video("path/to/pose_video.mp4")
|
||||
face_video = load_video("path/to/face_video.mp4")
|
||||
|
||||
def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
image = image.resize((width, height))
|
||||
return image, height, width
|
||||
|
||||
image, height, width = aspect_ratio_resize(image, pipe)
|
||||
|
||||
prompt = "A person dancing energetically in a studio"
|
||||
negative_prompt = "blurry, low quality"
|
||||
|
||||
# Advanced: Use temporal guidance and custom callback
|
||||
def callback_fn(pipe, step_index, timestep, callback_kwargs):
|
||||
# You can modify latents or other tensors here
|
||||
print(f"Step {step_index}, Timestep {timestep}")
|
||||
return callback_kwargs
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
pose_video=pose_video,
|
||||
face_video=face_video,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
segment_frame_length=77,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
prev_segment_conditioning_frames=5, # Use 5 frames for temporal guidance (1 or 5 recommended)
|
||||
callback_on_step_end=callback_fn,
|
||||
callback_on_step_end_tensor_inputs=["latents"],
|
||||
).frames[0]
|
||||
export_to_video(output, "animated_advanced.mp4", fps=30)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
#### Key Parameters
|
||||
|
||||
- **mode**: Choose between `"animate"` (default) or `"replace"`
|
||||
- **prev_segment_conditioning_frames**: Number of frames for temporal guidance (1 or 5 recommended). Using 5 provides better temporal consistency but requires more memory
|
||||
- **guidance_scale**: Controls how closely the output follows the text prompt. Higher values (5-7) produce results more aligned with the prompt. For Wan-Animate, CFG is disabled by default (`guidance_scale=1.0`) but can be enabled to support negative prompts and finer control over facial expressions. (Note that CFG will only target the text prompt and face conditioning.)
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Wan2.1 supports LoRAs with [`~loaders.WanLoraLoaderMixin.load_lora_weights`].
|
||||
@@ -281,10 +484,10 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
|
||||
# use "steamboat willie style" to trigger the LoRA
|
||||
prompt = """
|
||||
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
|
||||
"""
|
||||
|
||||
@@ -359,6 +562,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanAnimatePipeline
|
||||
|
||||
[[autodoc]] WanAnimatePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## WanPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
|
||||
[[autodoc]] pipelines.wan.pipeline_output.WanPipelineOutput
|
||||
|
||||
@@ -0,0 +1,492 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
|
||||
# Building Custom Blocks
|
||||
|
||||
[ModularPipelineBlocks](./pipeline_block) are the fundamental building blocks of a [`ModularPipeline`]. You can create custom blocks by defining their inputs, outputs, and computation logic. This guide demonstrates how to create and use a custom block.
|
||||
|
||||
> [!TIP]
|
||||
> Explore the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for official custom modular blocks like Nano Banana.
|
||||
|
||||
## Project Structure
|
||||
|
||||
Your custom block project should use the following structure:
|
||||
|
||||
```shell
|
||||
.
|
||||
├── block.py
|
||||
└── modular_config.json
|
||||
```
|
||||
|
||||
- `block.py` contains the custom block implementation
|
||||
- `modular_config.json` contains the metadata needed to load the block
|
||||
|
||||
## Example: Florence 2 Inpainting Block
|
||||
|
||||
In this example we will create a custom block that uses the [Florence 2](https://huggingface.co/docs/transformers/model_doc/florence2) model to process an input image and generate a mask for inpainting.
|
||||
|
||||
The first step is to define the components that the block will use. In this case, we will need to use the `Florence2ForConditionalGeneration` model and its corresponding processor `AutoProcessor`. When defining components, we must specify the name of the component within our pipeline, model class via `type_hint`, and provide a `pretrained_model_name_or_path` for the component if we intend to load the model weights from a specific repository on the Hub.
|
||||
|
||||
```py
|
||||
# Inside block.py
|
||||
from diffusers.modular_pipelines import (
|
||||
ModularPipelineBlocks,
|
||||
ComponentSpec,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
Next, we define the inputs and outputs of the block. The inputs include the image to be annotated, the annotation task, and the annotation prompt. The outputs include the generated mask image and annotations.
|
||||
|
||||
```py
|
||||
from typing import List, Union
|
||||
from PIL import Image, ImageDraw
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
PipelineState,
|
||||
ModularPipelineBlocks,
|
||||
InputParam,
|
||||
ComponentSpec,
|
||||
OutputParam,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"image",
|
||||
type_hint=Union[Image.Image, List[Image.Image]],
|
||||
required=True,
|
||||
description="Image(s) to annotate",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_task",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
default="<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
description="""Annotation Task to perform on the image.
|
||||
Supported Tasks:
|
||||
|
||||
<OD>
|
||||
<REFERRING_EXPRESSION_SEGMENTATION>
|
||||
<CAPTION>
|
||||
<DETAILED_CAPTION>
|
||||
<MORE_DETAILED_CAPTION>
|
||||
<DENSE_REGION_CAPTION>
|
||||
<CAPTION_TO_PHRASE_GROUNDING>
|
||||
<OPEN_VOCABULARY_DETECTION>
|
||||
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
description="""Annotation Prompt to provide more context to the task.
|
||||
Can be used to detect or segment out specific elements in the image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_output_type",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Availabe options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
- mask overlayed on the original image
|
||||
bounding_box:
|
||||
- bounding boxes drawn on the original image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_overlay",
|
||||
type_hint=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
description="",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"mask_image",
|
||||
type_hint=Image,
|
||||
description="Inpainting Mask for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"annotations",
|
||||
type_hint=dict,
|
||||
description="Annotations Predictions for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"image",
|
||||
type_hint=Image,
|
||||
description="Annotated input Image(s)",
|
||||
),
|
||||
]
|
||||
|
||||
```
|
||||
|
||||
Now we implement the `__call__` method, which contains the logic for processing the input image and generating the mask.
|
||||
|
||||
```py
|
||||
from typing import List, Union
|
||||
from PIL import Image, ImageDraw
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
PipelineState,
|
||||
ModularPipelineBlocks,
|
||||
InputParam,
|
||||
ComponentSpec,
|
||||
OutputParam,
|
||||
)
|
||||
from transformers import AutoProcessor, Florence2ForConditionalGeneration
|
||||
|
||||
|
||||
class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="image_annotator",
|
||||
type_hint=Florence2ForConditionalGeneration,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
ComponentSpec(
|
||||
name="image_annotator_processor",
|
||||
type_hint=AutoProcessor,
|
||||
pretrained_model_name_or_path="florence-community/Florence-2-base-ft",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"image",
|
||||
type_hint=Union[Image.Image, List[Image.Image]],
|
||||
required=True,
|
||||
description="Image(s) to annotate",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_task",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
default="<REFERRING_EXPRESSION_SEGMENTATION>",
|
||||
description="""Annotation Task to perform on the image.
|
||||
Supported Tasks:
|
||||
|
||||
<OD>
|
||||
<REFERRING_EXPRESSION_SEGMENTATION>
|
||||
<CAPTION>
|
||||
<DETAILED_CAPTION>
|
||||
<MORE_DETAILED_CAPTION>
|
||||
<DENSE_REGION_CAPTION>
|
||||
<CAPTION_TO_PHRASE_GROUNDING>
|
||||
<OPEN_VOCABULARY_DETECTION>
|
||||
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
required=True,
|
||||
description="""Annotation Prompt to provide more context to the task.
|
||||
Can be used to detect or segment out specific elements in the image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_output_type",
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Availabe options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
- mask overlayed on the original image
|
||||
bounding_box:
|
||||
- bounding boxes drawn on the original image
|
||||
""",
|
||||
),
|
||||
InputParam(
|
||||
"annotation_overlay",
|
||||
type_hint=bool,
|
||||
required=True,
|
||||
default=False,
|
||||
description="",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"mask_image",
|
||||
type_hint=Image,
|
||||
description="Inpainting Mask for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"annotations",
|
||||
type_hint=dict,
|
||||
description="Annotations Predictions for input Image(s)",
|
||||
),
|
||||
OutputParam(
|
||||
"image",
|
||||
type_hint=Image,
|
||||
description="Annotated input Image(s)",
|
||||
),
|
||||
]
|
||||
|
||||
def get_annotations(self, components, images, prompts, task):
|
||||
task_prompts = [task + prompt for prompt in prompts]
|
||||
|
||||
inputs = components.image_annotator_processor(
|
||||
text=task_prompts, images=images, return_tensors="pt"
|
||||
).to(components.image_annotator.device, components.image_annotator.dtype)
|
||||
|
||||
generated_ids = components.image_annotator.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
pixel_values=inputs["pixel_values"],
|
||||
max_new_tokens=1024,
|
||||
early_stopping=False,
|
||||
do_sample=False,
|
||||
num_beams=3,
|
||||
)
|
||||
annotations = components.image_annotator_processor.batch_decode(
|
||||
generated_ids, skip_special_tokens=False
|
||||
)
|
||||
outputs = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
outputs.append(
|
||||
components.image_annotator_processor.post_process_generation(
|
||||
annotation, task=task, image_size=(image.width, image.height)
|
||||
)
|
||||
)
|
||||
return outputs
|
||||
|
||||
def prepare_mask(self, images, annotations, overlay=False, fill="white"):
|
||||
masks = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
mask_image = image.copy() if overlay else Image.new("L", image.size, 0)
|
||||
draw = ImageDraw.Draw(mask_image)
|
||||
|
||||
for _, _annotation in annotation.items():
|
||||
if "polygons" in _annotation:
|
||||
for polygon in _annotation["polygons"]:
|
||||
polygon = np.array(polygon).reshape(-1, 2)
|
||||
if len(polygon) < 3:
|
||||
continue
|
||||
polygon = polygon.reshape(-1).tolist()
|
||||
draw.polygon(polygon, fill=fill)
|
||||
|
||||
elif "bbox" in _annotation:
|
||||
bbox = _annotation["bbox"]
|
||||
draw.rectangle(bbox, fill="white")
|
||||
|
||||
masks.append(mask_image)
|
||||
|
||||
return masks
|
||||
|
||||
def prepare_bounding_boxes(self, images, annotations):
|
||||
outputs = []
|
||||
for image, annotation in zip(images, annotations):
|
||||
image_copy = image.copy()
|
||||
draw = ImageDraw.Draw(image_copy)
|
||||
for _, _annotation in annotation.items():
|
||||
bbox = _annotation["bbox"]
|
||||
label = _annotation["label"]
|
||||
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text((bbox[0], bbox[1] - 20), label, fill="red")
|
||||
|
||||
outputs.append(image_copy)
|
||||
|
||||
return outputs
|
||||
|
||||
def prepare_inputs(self, images, prompts):
|
||||
prompts = prompts or ""
|
||||
|
||||
if isinstance(images, Image.Image):
|
||||
images = [images]
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if len(images) != len(prompts):
|
||||
raise ValueError("Number of images and annotation prompts must match.")
|
||||
|
||||
return images, prompts
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
images, annotation_task_prompt = self.prepare_inputs(
|
||||
block_state.image, block_state.annotation_prompt
|
||||
)
|
||||
task = block_state.annotation_task
|
||||
fill = block_state.fill
|
||||
|
||||
annotations = self.get_annotations(
|
||||
components, images, annotation_task_prompt, task
|
||||
)
|
||||
block_state.annotations = annotations
|
||||
if block_state.annotation_output_type == "mask_image":
|
||||
block_state.mask_image = self.prepare_mask(images, annotations)
|
||||
else:
|
||||
block_state.mask_image = None
|
||||
|
||||
if block_state.annotation_output_type == "mask_overlay":
|
||||
block_state.image = self.prepare_mask(images, annotations, overlay=True, fill=fill)
|
||||
|
||||
elif block_state.annotation_output_type == "bounding_box":
|
||||
block_state.image = self.prepare_bounding_boxes(images, annotations)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
```
|
||||
|
||||
Once we have defined our custom block, we can save it to the Hub, using either the CLI or the [`push_to_hub`] method. This will make it easy to share and reuse our custom block with other pipelines.
|
||||
|
||||
<hfoptions id="share">
|
||||
<hfoption id="hf CLI">
|
||||
|
||||
```shell
|
||||
# In the folder with the `block.py` file, run:
|
||||
diffusers-cli custom_block
|
||||
```
|
||||
|
||||
Then upload the block to the Hub:
|
||||
|
||||
```shell
|
||||
hf upload <your repo id> . .
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="push_to_hub">
|
||||
|
||||
```py
|
||||
from block import Florence2ImageAnnotatorBlock
|
||||
block = Florence2ImageAnnotatorBlock()
|
||||
block.push_to_hub("<your repo id>")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Using Custom Blocks
|
||||
|
||||
Load the custom block with [`~ModularPipelineBlocks.from_pretrained`] and set `trust_remote_code=True`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# Fetch the Florence2 image annotator block that will create our mask
|
||||
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True)
|
||||
|
||||
my_blocks = INPAINT_BLOCKS.copy()
|
||||
# insert the annotation block before the image encoding step
|
||||
my_blocks.insert("image_annotator", image_annotator_block, 1)
|
||||
|
||||
# Create our initial set of inpainting blocks
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(my_blocks)
|
||||
|
||||
repo_id = "diffusers/modular-stable-diffusion-xl-base-1.0"
|
||||
pipe = blocks.init_pipeline(repo_id)
|
||||
pipe.load_components(torch_dtype=torch.float16, device_map="cuda", trust_remote_code=True)
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true")
|
||||
image = image.resize((1024, 1024))
|
||||
|
||||
prompt = ["A red car"]
|
||||
annotation_task = "<REFERRING_EXPRESSION_SEGMENTATION>"
|
||||
annotation_prompt = ["the car"]
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
annotation_task=annotation_task,
|
||||
annotation_prompt=annotation_prompt,
|
||||
annotation_output_type="mask_image",
|
||||
num_inference_steps=35,
|
||||
guidance_scale=7.5,
|
||||
strength=0.95,
|
||||
output="images"
|
||||
)
|
||||
output[0].save("florence-inpainting.png")
|
||||
```
|
||||
|
||||
## Editing Custom Blocks
|
||||
|
||||
By default, custom blocks are saved in your cache directory. Use the `local_dir` argument to download and edit a custom block in a specific folder.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# Fetch the Florence2 image annotator block that will create our mask
|
||||
image_annotator_block = ModularPipelineBlocks.from_pretrained("diffusers/florence-2-custom-block", trust_remote_code=True, local_dir="/my-local-folder")
|
||||
```
|
||||
|
||||
Any changes made to the block files in this folder will be reflected when you load the block again.
|
||||
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# LoopSequentialPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `intermediate_inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a loop. Data flows circularly, using `inputs` and `intermediate_outputs`, and each block is run iteratively. This is typically used to create a denoising loop which is iterative by default.
|
||||
|
||||
This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBlocks`].
|
||||
|
||||
@@ -21,7 +21,6 @@ This guide shows you how to create [`~modular_pipelines.LoopSequentialPipelineBl
|
||||
[`~modular_pipelines.LoopSequentialPipelineBlocks`], is also known as the *loop wrapper* because it defines the loop structure, iteration variables, and configuration. Within the loop wrapper, you need the following variables.
|
||||
|
||||
- `loop_inputs` are user provided values and equivalent to [`~modular_pipelines.ModularPipelineBlocks.inputs`].
|
||||
- `loop_intermediate_inputs` are intermediate variables from the [`~modular_pipelines.PipelineState`] and equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_inputs`].
|
||||
- `loop_intermediate_outputs` are new intermediate variables created by the block and added to the [`~modular_pipelines.PipelineState`]. It is equivalent to [`~modular_pipelines.ModularPipelineBlocks.intermediate_outputs`].
|
||||
- `__call__` method defines the loop structure and iteration logic.
|
||||
|
||||
@@ -90,4 +89,4 @@ Add more loop blocks to run within each iteration with [`~modular_pipelines.Loop
|
||||
|
||||
```py
|
||||
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
|
||||
```
|
||||
```
|
||||
|
||||
@@ -37,17 +37,7 @@ A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermedi
|
||||
]
|
||||
```
|
||||
|
||||
- `intermediate_inputs` are values typically created from a previous block but it can also be directly provided if no preceding block generates them. Unlike `inputs`, `intermediate_inputs` can be modified.
|
||||
|
||||
Use `InputParam` to define `intermediate_inputs`.
|
||||
|
||||
```py
|
||||
user_intermediate_inputs = [
|
||||
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
|
||||
]
|
||||
```
|
||||
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `intermediate_inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
|
||||
Use `OutputParam` to define `intermediate_outputs`.
|
||||
|
||||
@@ -65,8 +55,8 @@ The intermediate inputs and outputs share data to connect blocks. They are acces
|
||||
|
||||
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
|
||||
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs` and `intermediate_inputs`.
|
||||
2. Implement the computation logic on the `inputs` and `intermediate_inputs`.
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
|
||||
2. Implement the computation logic on the `inputs`.
|
||||
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
|
||||
4. Return the components and state which becomes available to the next block.
|
||||
|
||||
@@ -76,7 +66,7 @@ def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Your computation logic here
|
||||
# block_state contains all your inputs and intermediate_inputs
|
||||
# block_state contains all your inputs
|
||||
# Access them like: block_state.image, block_state.processed_image
|
||||
|
||||
# Update the pipeline state with your updated block_states
|
||||
@@ -112,4 +102,4 @@ def __call__(self, components, state):
|
||||
unet = components.unet
|
||||
vae = components.vae
|
||||
scheduler = components.scheduler
|
||||
```
|
||||
```
|
||||
|
||||
@@ -183,7 +183,7 @@ from diffusers.modular_pipelines import ComponentsManager
|
||||
components = ComponentManager()
|
||||
|
||||
dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", components_manager=components, collection="diffdiff")
|
||||
dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.load_componenets(torch_dtype=torch.float16)
|
||||
dd_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
|
||||
@@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# SequentialPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `intermediate_inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
|
||||
[`~modular_pipelines.SequentialPipelineBlocks`] are a multi-block type that composes other [`~modular_pipelines.ModularPipelineBlocks`] together in a sequence. Data flows linearly from one block to the next using `inputs` and `intermediate_outputs`. Each block in [`~modular_pipelines.SequentialPipelineBlocks`] usually represents a step in the pipeline, and by combining them, you gradually build a pipeline.
|
||||
|
||||
This guide shows you how to connect two blocks into a [`~modular_pipelines.SequentialPipelineBlocks`].
|
||||
|
||||
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `intermediate_inputs`.
|
||||
Create two [`~modular_pipelines.ModularPipelineBlocks`]. The first block, `InputBlock`, outputs a `batch_size` value and the second block, `ImageEncoderBlock` uses `batch_size` as `inputs`.
|
||||
|
||||
<hfoptions id="sequential">
|
||||
<hfoption id="InputBlock">
|
||||
@@ -110,4 +110,4 @@ Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by cal
|
||||
```py
|
||||
print(blocks)
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
|
||||
@@ -21,6 +21,7 @@ Refer to the table below for an overview of the available attention families and
|
||||
| attention family | main feature |
|
||||
|---|---|
|
||||
| FlashAttention | minimizes memory reads/writes through tiling and recomputation |
|
||||
| AI Tensor Engine for ROCm | FlashAttention implementation optimized for AMD ROCm accelerators |
|
||||
| SageAttention | quantizes attention to int8 |
|
||||
| PyTorch native | built-in PyTorch implementation using [scaled_dot_product_attention](./fp16#scaled-dot-product-attention) |
|
||||
| xFormers | memory-efficient attention with support for various attention kernels |
|
||||
@@ -138,11 +139,14 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
|
||||
| `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
|
||||
| `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
|
||||
| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
| `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
|
||||
| `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
|
||||
| `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
|
||||
| `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
|
||||
| `_sage_qk_int8_pv_fp8_cuda_sm90` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (SM90) |
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoModel
|
||||
|
||||
The [`AutoModel`] class automatically detects and loads the correct model class (UNet, transformer, VAE) from a `config.json` file. You don't need to know the specific model class name ahead of time. It supports data types and device placement, and works across model types and libraries.
|
||||
|
||||
The example below loads a transformer from Diffusers and a text encoder from Transformers. Use the `subfolder` parameter to specify where to load the `config.json` file from.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel, DiffusionPipeline
|
||||
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
|
||||
text_encoder = AutoModel.from_pretrained(
|
||||
"Qwen/Qwen-Image", subfolder="text_encoder", torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel
|
||||
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"custom/custom-transformer-model", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
|
||||
|
||||
> [!NOTE]
|
||||
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
|
||||
@@ -1,8 +1,10 @@
|
||||
- sections:
|
||||
- local: index
|
||||
title: 🧨 Diffusers
|
||||
- local: quicktour
|
||||
title: Tour rápido
|
||||
- local: installation
|
||||
title: Instalação
|
||||
- local: index
|
||||
title: Diffusers
|
||||
- local: installation
|
||||
title: Instalação
|
||||
- local: quicktour
|
||||
title: Tour rápido
|
||||
- local: stable_diffusion
|
||||
title: Desempenho básico
|
||||
title: Primeiros passos
|
||||
|
||||
@@ -18,11 +18,11 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Diffusers
|
||||
|
||||
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou queira treinar seu próprio modelo de difusão, 🤗 Diffusers é uma modular caixa de ferramentas que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
|
||||
🤗 Diffusers é uma biblioteca de modelos de difusão de última geração para geração de imagens, áudio e até mesmo estruturas 3D de moléculas. Se você está procurando uma solução de geração simples ou quer treinar seu próprio modelo de difusão, 🤗 Diffusers é uma caixa de ferramentas modular que suporta ambos. Nossa biblioteca é desenhada com foco em [usabilidade em vez de desempenho](conceptual/philosophy#usability-over-performance), [simples em vez de fácil](conceptual/philosophy#simple-over-easy) e [customizável em vez de abstrações](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
|
||||
|
||||
A Biblioteca tem três componentes principais:
|
||||
|
||||
- Pipelines de última geração para a geração em poucas linhas de código. Têm muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
|
||||
- Pipelines de última geração para a geração em poucas linhas de código. Há muitos pipelines no 🤗 Diffusers, veja a tabela no pipeline [Visão geral](api/pipelines/overview) para uma lista completa de pipelines disponíveis e as tarefas que eles resolvem.
|
||||
- Intercambiáveis [agendadores de ruído](api/schedulers/overview) para balancear as compensações entre velocidade e qualidade de geração.
|
||||
- [Modelos](api/models) pré-treinados que podem ser usados como se fossem blocos de construção, e combinados com agendadores, para criar seu próprio sistema de difusão de ponta a ponta.
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Recomenda-se instalar 🤗 Diffusers em um [ambiente virtual](https://docs.python.org/3/library/venv.html).
|
||||
Se você não está familiarizado com ambiente virtuals, veja o [guia](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).
|
||||
Um ambiente virtual deixa mais fácil gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
|
||||
Um ambiente virtual facilita gerenciar diferentes projetos e evitar problemas de compatibilidade entre dependências.
|
||||
|
||||
Comece criando um ambiente virtual no diretório do projeto:
|
||||
|
||||
@@ -100,12 +100,12 @@ pip install -e ".[flax]"
|
||||
</jax>
|
||||
</frameworkcontent>
|
||||
|
||||
Esses comandos irá linkar a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
|
||||
Esses comandos irão vincular a pasta que você clonou o repositório e os caminhos das suas bibliotecas Python.
|
||||
Python então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.
|
||||
Por exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.10/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.
|
||||
|
||||
> [!WARNING]
|
||||
> Você deve deixar a pasta `diffusers` se você quiser continuar usando a biblioteca.
|
||||
> Você deve manter a pasta `diffusers` se quiser continuar usando a biblioteca.
|
||||
|
||||
Agora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:
|
||||
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
# Desempenho básico
|
||||
|
||||
Difusão é um processo aleatório que demanda muito processamento. Você pode precisar executar o [`DiffusionPipeline`] várias vezes antes de obter o resultado desejado. Por isso é importante equilibrar cuidadosamente a velocidade de geração e o uso de memória para iterar mais rápido.
|
||||
|
||||
Este guia recomenda algumas dicas básicas de desempenho para usar o [`DiffusionPipeline`]. Consulte a seção de documentação sobre Otimização de Inferência, como [Acelerar inferência](./optimization/fp16) ou [Reduzir uso de memória](./optimization/memory) para guias de desempenho mais detalhados.
|
||||
|
||||
## Uso de memória
|
||||
|
||||
Reduzir a quantidade de memória usada indiretamente acelera a geração e pode ajudar um modelo a caber no dispositivo.
|
||||
|
||||
O método [`~DiffusionPipeline.enable_model_cpu_offload`] move um modelo para a CPU quando não está em uso para economizar memória da GPU.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
pipeline(prompt).images[0]
|
||||
print(f"Memória máxima reservada: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
|
||||
```
|
||||
|
||||
## Velocidade de inferência
|
||||
|
||||
O processo de remoção de ruído é o mais exigente computacionalmente durante a difusão. Métodos que otimizam este processo aceleram a velocidade de inferência. Experimente os seguintes métodos para acelerar.
|
||||
|
||||
- Adicione `device_map="cuda"` para colocar o pipeline em uma GPU. Colocar um modelo em um acelerador, como uma GPU, aumenta a velocidade porque realiza computações em paralelo.
|
||||
- Defina `torch_dtype=torch.bfloat16` para executar o pipeline em meia-precisão. Reduzir a precisão do tipo de dado aumenta a velocidade porque leva menos tempo para realizar computações em precisão mais baixa.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import time
|
||||
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
```
|
||||
|
||||
- Use um agendador mais rápido, como [`DPMSolverMultistepScheduler`], que requer apenas ~20-25 passos.
|
||||
- Defina `num_inference_steps` para um valor menor. Reduzir o número de passos de inferência reduz o número total de computações. No entanto, isso pode resultar em menor qualidade de geração.
|
||||
|
||||
```py
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
|
||||
start_time = time.perf_counter()
|
||||
image = pipeline(prompt).images[0]
|
||||
end_time = time.perf_counter()
|
||||
|
||||
print(f"Geração de imagem levou {end_time - start_time:.3f} segundos")
|
||||
```
|
||||
|
||||
## Qualidade de geração
|
||||
|
||||
Muitos modelos de difusão modernos entregam imagens de alta qualidade imediatamente. No entanto, você ainda pode melhorar a qualidade de geração experimentando o seguinte.
|
||||
|
||||
- Experimente um prompt mais detalhado e descritivo. Inclua detalhes como o meio da imagem, assunto, estilo e estética. Um prompt negativo também pode ajudar, guiando um modelo para longe de características indesejáveis usando palavras como baixa qualidade ou desfocado.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
negative_prompt = "low quality, blurry, ugly, poor details"
|
||||
pipeline(prompt, negative_prompt=negative_prompt).images[0]
|
||||
```
|
||||
|
||||
Para mais detalhes sobre como criar prompts melhores, consulte a documentação sobre [Técnicas de prompt](./using-diffusers/weighted_prompts).
|
||||
|
||||
- Experimente um agendador diferente, como [`HeunDiscreteScheduler`] ou [`LMSDiscreteScheduler`], que sacrifica velocidade de geração por qualidade.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, HeunDiscreteScheduler
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
pipeline.scheduler = HeunDiscreteScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
prompt = """
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
negative_prompt = "low quality, blurry, ugly, poor details"
|
||||
pipeline(prompt, negative_prompt=negative_prompt).images[0]
|
||||
```
|
||||
|
||||
## Próximos passos
|
||||
|
||||
Diffusers oferece otimizações mais avançadas e poderosas, como [group-offloading](./optimization/memory#group-offloading) e [compilação regional](./optimization/fp16#regional-compilation). Para saber mais sobre como maximizar o desempenho, consulte a seção sobre Otimização de Inferência.
|
||||
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
|
||||
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
|
||||
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
|
||||
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
|
||||
|
||||
| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports Controlnet with Flux Fill model.| [Flux Fill ControlNet Pipeline](#Flux-Fill-ControlNet-Pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) |
|
||||
|
||||
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
|
||||
|
||||
@@ -5488,7 +5488,7 @@ Editing at Scale", many thanks to their contribution!
|
||||
|
||||
This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
|
||||
|
||||
As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
|
||||
As explained in Section 3 of [the paper](https://huggingface.co/papers/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
|
||||
|
||||
## Example Usage
|
||||
|
||||
@@ -5527,3 +5527,106 @@ images = pipe(
|
||||
).images
|
||||
images[0].save("pizzeria.png")
|
||||
```
|
||||
|
||||
# Flux Fill ControlNet Pipeline
|
||||
|
||||
This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.
|
||||
|
||||
While FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.
|
||||
|
||||
## Example Usage
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import (
|
||||
FluxControlNetModel,
|
||||
FluxPriorReduxPipeline,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# NEW PIPELINE (updated name)
|
||||
from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Models
|
||||
base_model = "black-forest-labs/FLUX.1-Fill-dev"
|
||||
controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0"
|
||||
prior_model = "black-forest-labs/FLUX.1-Redux-dev"
|
||||
|
||||
# Load ControlNet
|
||||
controlnet = FluxControlNetModel.from_pretrained(
|
||||
controlnet_model,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
|
||||
# Load Fill + ControlNet Pipeline
|
||||
fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
|
||||
base_model,
|
||||
controlnet=controlnet,
|
||||
torch_dtype=dtype,
|
||||
).to(device)
|
||||
|
||||
# OPTIONAL FP8
|
||||
# fill_pipe.transformer.enable_layerwise_casting(
|
||||
# storage_dtype=torch.float8_e4m3fn,
|
||||
# compute_dtype=torch.bfloat16
|
||||
# )
|
||||
|
||||
# OPTIONAL Prior Redux
|
||||
#pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
|
||||
# prior_model,
|
||||
# torch_dtype=dtype,
|
||||
#).to(device)
|
||||
|
||||
# Inputs
|
||||
|
||||
# combined_image = load_image("person_input.png")
|
||||
|
||||
|
||||
# 1. Prior conditioning
|
||||
#prior_out = pipe_prior_redux(
|
||||
# image=cloth_image,
|
||||
# prompt=cloth_prompt,
|
||||
#)
|
||||
|
||||
# 2. Fill Inpaint with ControlNet
|
||||
|
||||
# canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6).
|
||||
|
||||
img = load_image(r"imgs/background.jpg")
|
||||
mask = load_image(r"imgs/mask.png")
|
||||
|
||||
control_image_depth = load_image(r"imgs/dog_depth _2.png")
|
||||
|
||||
result = fill_pipe(
|
||||
prompt="a dog on a bench",
|
||||
image=img,
|
||||
mask_image=mask,
|
||||
|
||||
control_image=control_image_depth,
|
||||
control_mode=[2], # union mode
|
||||
control_guidance_start=0.0,
|
||||
control_guidance_end=0.8,
|
||||
controlnet_conditioning_scale=0.9,
|
||||
|
||||
height=1024,
|
||||
width=1024,
|
||||
|
||||
strength=1.0,
|
||||
guidance_scale=50.0,
|
||||
num_inference_steps=60,
|
||||
max_sequence_length=512,
|
||||
|
||||
# **prior_out,
|
||||
)
|
||||
|
||||
# result.images[0].save("flux_fill_controlnet_inpaint.png")
|
||||
|
||||
from datetime import datetime
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
result.images[0].save(f"flux_fill_controlnet_inpaint_depth{timestamp}.jpg")
|
||||
```
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ def check_size(image, height, width):
|
||||
raise ValueError(f"Image size should be {height}x{width}, but got {h}x{w}")
|
||||
|
||||
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int] = (0, 0)):
|
||||
def overlay_inner_image(image, inner_image, paste_offset: Tuple[int, ...] = (0, 0)):
|
||||
inner_image = inner_image.convert("RGBA")
|
||||
image = image.convert("RGB")
|
||||
|
||||
|
||||
@@ -1966,16 +1966,21 @@ class MatryoshkaUNet2DConditionModel(
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
@@ -2294,10 +2299,10 @@ class MatryoshkaUNet2DConditionModel(
|
||||
|
||||
def _check_config(
|
||||
self,
|
||||
down_block_types: Tuple[str],
|
||||
up_block_types: Tuple[str],
|
||||
down_block_types: Tuple[str, ...],
|
||||
up_block_types: Tuple[str, ...],
|
||||
only_cross_attention: Union[bool, Tuple[bool]],
|
||||
block_out_channels: Tuple[int],
|
||||
block_out_channels: Tuple[int, ...],
|
||||
layers_per_block: Union[int, Tuple[int]],
|
||||
cross_attention_dim: Union[int, Tuple[int]],
|
||||
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
||||
|
||||
@@ -438,16 +438,21 @@ class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DCond
|
||||
center_input_sample: bool = False,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str] = (
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: Union[int, Tuple[int]] = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -490,7 +490,7 @@ class RegionalPromptingStableDiffusionPipeline(
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
@@ -841,7 +841,7 @@ class RegionalPromptingStableDiffusionPipeline(
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
||||
Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only applies
|
||||
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
@@ -872,7 +872,7 @@ class RegionalPromptingStableDiffusionPipeline(
|
||||
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
@@ -1062,7 +1062,7 @@ class RegionalPromptingStableDiffusionPipeline(
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
@@ -1668,7 +1668,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
|
||||
@@ -268,12 +268,11 @@ provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_f
|
||||
**important**
|
||||
|
||||
> [!NOTE]
|
||||
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
|
||||
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source.
|
||||
> To do this, execute the following steps in a new virtual environment:
|
||||
> ```
|
||||
> git clone https://github.com/huggingface/diffusers
|
||||
> cd diffusers
|
||||
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
|
||||
> pip install -e .
|
||||
> ```
|
||||
|
||||
|
||||
@@ -0,0 +1,315 @@
|
||||
# DreamBooth training example for FLUX.2 [dev]
|
||||
|
||||
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
|
||||
|
||||
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
|
||||
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
|
||||
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
|
||||
|
||||
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
|
||||
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
|
||||
> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux2-training)
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
hf auth login
|
||||
```
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_flux.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
|
||||
|
||||
## Memory Optimizations
|
||||
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
|
||||
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
|
||||
### Remote Text Encoder
|
||||
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
|
||||
This way, the text encoder model is not loaded into memory during training.
|
||||
> [!NOTE]
|
||||
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
|
||||
### CPU Offloading
|
||||
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
|
||||
### Latent Caching
|
||||
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
|
||||
### QLoRA: Low Precision Training with Quantization
|
||||
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
|
||||
- **FP8 training** with `torchao`:
|
||||
enable FP8 training by passing `--do_fp8_training`.
|
||||
> [!IMPORTANT] Since we are utilizing FP8 tensor cores we need CUDA GPUs with compute capability at least 8.9 or greater.
|
||||
> If you're looking for memory-efficient training on relatively older cards, we encourage you to check out other trainers like SimpleTuner, ai-toolkit, etc.
|
||||
- **NF4 training** with `bitsandbytes`:
|
||||
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
|
||||
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
|
||||
### Gradient Checkpointing and Accumulation
|
||||
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
|
||||
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
|
||||
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
|
||||
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
|
||||
### 8-bit-Adam Optimizer
|
||||
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
|
||||
Make sure to install `bitsandbytes` if you want to do so.
|
||||
### Image Resolution
|
||||
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
|
||||
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
|
||||
### Precision of saved LoRA layers
|
||||
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
|
||||
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
|
||||
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux2"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux2.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--use_8bit_adam \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--optimizer="adamW" \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
> [!NOTE]
|
||||
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||
|
||||
## LoRA + DreamBooth
|
||||
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Prodigy Optimizer
|
||||
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
|
||||
By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
|
||||
|
||||
to use prodigy, first make sure to install the prodigyopt library: `pip install prodigyopt`, and then specify -
|
||||
```bash
|
||||
--optimizer="prodigy"
|
||||
```
|
||||
> [!TIP]
|
||||
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
|
||||
|
||||
To perform DreamBooth with LoRA, run:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-flux2-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_flux2.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--optimizer="prodigy" \
|
||||
--learning_rate=1. \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant_with_warmup" \
|
||||
--lr_warmup_steps=100 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### LoRA Rank and Alpha
|
||||
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
|
||||
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
|
||||
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
|
||||
- lora_alpha vs. rank:
|
||||
This ratio dictates the LoRA's effective strength:
|
||||
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
|
||||
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
|
||||
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
|
||||
|
||||
> [!TIP]
|
||||
> A common starting point is to set `lora_alpha` equal to `rank`.
|
||||
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
|
||||
> to give the LoRA updates more influence without increasing parameter count.
|
||||
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
|
||||
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
|
||||
|
||||
### Target Modules
|
||||
When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
|
||||
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer(DiT). With this change, we may also want to explore
|
||||
applying LoRA training onto different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`- in which you can specify in a comma separated string
|
||||
the exact modules for LoRA training. Here are some examples of target modules you can provide:
|
||||
- for attention only layers: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
|
||||
- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
|
||||
- to train the same modules as in ostris ai-toolkit / replicate trainer: `--lora_blocks="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear, norm1.linear,norm.linear,proj_mlp,proj_out"`
|
||||
> [!NOTE]
|
||||
> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma separated string:
|
||||
> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. - `single_transformer_blocks.i.attn.to_k`
|
||||
> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. - `transformer_blocks.i.attn.to_k`
|
||||
> [!NOTE]
|
||||
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
|
||||
|
||||
|
||||
|
||||
## Training Image-to-Image
|
||||
|
||||
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
|
||||
|
||||
**important**
|
||||
|
||||
**Important**
|
||||
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
|
||||
To start, you must have a dataset containing triplets:
|
||||
|
||||
* Condition image - the input image to be transformed.
|
||||
* Target image - the desired output image after transformation.
|
||||
* Instruction - a text prompt describing the transformation from the condition image to the target image.
|
||||
|
||||
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth_lora_flux2_img2img.py \
|
||||
--pretrained_model_name_or_path=black-forest-labs/FLUX.2-dev \
|
||||
--output_dir="flux2-i2i" \
|
||||
--dataset_name="kontext-community/relighting" \
|
||||
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
|
||||
--do_fp8_training \
|
||||
--gradient_checkpointing \
|
||||
--remote_text_encoder \
|
||||
--cache_latents \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--optimizer="adamw" \
|
||||
--use_8bit_adam \
|
||||
--cache_latents \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_scheduler="constant_with_warmup" \
|
||||
--lr_warmup_steps=200 \
|
||||
--max_train_steps=1000 \
|
||||
--rank=16\
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
More generally, when performing I2I fine-tuning, we expect you to:
|
||||
|
||||
* Have a dataset `kontext-community/relighting`
|
||||
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
|
||||
|
||||
### Misc notes
|
||||
|
||||
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
|
||||
### Aspect Ratio Bucketing
|
||||
we've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
|
||||
|
||||
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
|
||||
|
||||
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
|
||||
`
|
||||
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
|
||||
@@ -0,0 +1,262 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAFlux2(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "dog"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_flux2.py"
|
||||
transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
|
||||
|
||||
def test_dreambooth_lora_flux2(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
|
||||
starts_with_transformer = all(
|
||||
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 8
|
||||
--checkpointing_steps=2
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_with_metadata(self):
|
||||
# Use a `lora_alpha` that is different from `rank`.
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--max_sequence_length 8
|
||||
--text_encoder_out_layers 1
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
|
||||
- you can either provide your own folder as `--train_data_dir`
|
||||
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
|
||||
|
||||
If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
|
||||
|
||||
Below, we explain both in more detail.
|
||||
|
||||
#### Provide the dataset as a folder
|
||||
|
||||
@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
return res.expand(broadcast_shape)
|
||||
|
||||
|
||||
def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
|
||||
"""
|
||||
if tensor.ndim == 2:
|
||||
tensor = tensor.unsqueeze(0)
|
||||
channels = tensor.shape[0]
|
||||
if channels == 3:
|
||||
return tensor
|
||||
if channels == 1:
|
||||
return tensor.repeat(3, 1, 1)
|
||||
if channels == 2:
|
||||
return torch.cat([tensor, tensor[:1]], dim=0)
|
||||
if channels > 3:
|
||||
return tensor[:3]
|
||||
raise ValueError(f"Unsupported number of channels: {channels}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -260,6 +278,11 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preserve_input_precision",
|
||||
action="store_true",
|
||||
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
@@ -453,19 +476,41 @@ def main(args):
|
||||
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
||||
|
||||
# Preprocessing the datasets and DataLoaders creation.
|
||||
spatial_augmentations = [
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
]
|
||||
|
||||
augmentations = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
|
||||
spatial_augmentations
|
||||
+ [
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
precision_augmentations = transforms.Compose(
|
||||
[
|
||||
transforms.PILToTensor(),
|
||||
transforms.Lambda(_ensure_three_channels),
|
||||
transforms.ConvertImageDtype(torch.float32),
|
||||
]
|
||||
+ spatial_augmentations
|
||||
+ [transforms.Normalize([0.5], [0.5])]
|
||||
)
|
||||
|
||||
def transform_images(examples):
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
processed = []
|
||||
for image in examples["image"]:
|
||||
if not args.preserve_input_precision:
|
||||
processed.append(augmentations(image.convert("RGB")))
|
||||
else:
|
||||
precise_image = image
|
||||
if precise_image.mode == "P":
|
||||
precise_image = precise_image.convert("RGB")
|
||||
processed.append(precision_augmentations(precise_image))
|
||||
return {"input": processed}
|
||||
|
||||
logger.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
|
||||
@@ -0,0 +1,475 @@
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoProcessor, GenerationConfig, Mistral3ForConditionalGeneration
|
||||
|
||||
from diffusers import AutoencoderKLFlux2, FlowMatchEulerDiscreteScheduler, Flux2Pipeline, Flux2Transformer2DModel
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
"""
|
||||
# VAE
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id "diffusers-internal-dev/new-model-image" \
|
||||
--vae_filename "flux2-vae.sft" \
|
||||
--output_path "/raid/yiyi/dummy-flux2-diffusers" \
|
||||
--vae
|
||||
|
||||
# DiT
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
|
||||
--dit_filename flux-dev-dummy.sft \
|
||||
--dit \
|
||||
--output_path .
|
||||
|
||||
# Full pipe
|
||||
|
||||
python scripts/convert_flux2_to_diffusers.py \
|
||||
--original_state_dict_repo_id diffusers-internal-dev/new-model-image \
|
||||
--dit_filename flux-dev-dummy.sft \
|
||||
--vae_filename "flux2-vae.sft" \
|
||||
--dit --vae --full_pipe \
|
||||
--output_path .
|
||||
"""
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
||||
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
|
||||
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
|
||||
parser.add_argument("--vae", action="store_true")
|
||||
parser.add_argument("--dit", action="store_true")
|
||||
parser.add_argument("--vae_dtype", type=str, default="fp32")
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--full_pipe", action="store_true")
|
||||
parser.add_argument("--output_path", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
DIFFUSERS_VAE_TO_FLUX2_MAPPING = {
|
||||
"encoder.conv_in.weight": "encoder.conv_in.weight",
|
||||
"encoder.conv_in.bias": "encoder.conv_in.bias",
|
||||
"encoder.conv_out.weight": "encoder.conv_out.weight",
|
||||
"encoder.conv_out.bias": "encoder.conv_out.bias",
|
||||
"encoder.conv_norm_out.weight": "encoder.norm_out.weight",
|
||||
"encoder.conv_norm_out.bias": "encoder.norm_out.bias",
|
||||
"decoder.conv_in.weight": "decoder.conv_in.weight",
|
||||
"decoder.conv_in.bias": "decoder.conv_in.bias",
|
||||
"decoder.conv_out.weight": "decoder.conv_out.weight",
|
||||
"decoder.conv_out.bias": "decoder.conv_out.bias",
|
||||
"decoder.conv_norm_out.weight": "decoder.norm_out.weight",
|
||||
"decoder.conv_norm_out.bias": "decoder.norm_out.bias",
|
||||
"quant_conv.weight": "encoder.quant_conv.weight",
|
||||
"quant_conv.bias": "encoder.quant_conv.bias",
|
||||
"post_quant_conv.weight": "decoder.post_quant_conv.weight",
|
||||
"post_quant_conv.bias": "decoder.post_quant_conv.bias",
|
||||
"bn.running_mean": "bn.running_mean",
|
||||
"bn.running_var": "bn.running_var",
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
def update_vae_resnet_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
|
||||
for ldm_key in keys:
|
||||
diffusers_key = ldm_key.replace(mapping["old"], mapping["new"]).replace("nin_shortcut", "conv_shortcut")
|
||||
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
|
||||
|
||||
|
||||
def update_vae_attentions_ldm_to_diffusers(keys, new_checkpoint, checkpoint, mapping):
|
||||
for ldm_key in keys:
|
||||
diffusers_key = (
|
||||
ldm_key.replace(mapping["old"], mapping["new"])
|
||||
.replace("norm.weight", "group_norm.weight")
|
||||
.replace("norm.bias", "group_norm.bias")
|
||||
.replace("q.weight", "to_q.weight")
|
||||
.replace("q.bias", "to_q.bias")
|
||||
.replace("k.weight", "to_k.weight")
|
||||
.replace("k.bias", "to_k.bias")
|
||||
.replace("v.weight", "to_v.weight")
|
||||
.replace("v.bias", "to_v.bias")
|
||||
.replace("proj_out.weight", "to_out.0.weight")
|
||||
.replace("proj_out.bias", "to_out.0.bias")
|
||||
)
|
||||
new_checkpoint[diffusers_key] = checkpoint.get(ldm_key)
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
shape = new_checkpoint[diffusers_key].shape
|
||||
|
||||
if len(shape) == 3:
|
||||
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0]
|
||||
elif len(shape) == 4:
|
||||
new_checkpoint[diffusers_key] = new_checkpoint[diffusers_key][:, :, 0, 0]
|
||||
|
||||
|
||||
def convert_flux2_vae_checkpoint_to_diffusers(vae_state_dict, config):
|
||||
new_checkpoint = {}
|
||||
for diffusers_key, ldm_key in DIFFUSERS_VAE_TO_FLUX2_MAPPING.items():
|
||||
if ldm_key not in vae_state_dict:
|
||||
continue
|
||||
new_checkpoint[diffusers_key] = vae_state_dict[ldm_key]
|
||||
|
||||
# Retrieves the keys for the encoder down blocks only
|
||||
num_down_blocks = len(config["down_block_types"])
|
||||
down_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
||||
}
|
||||
|
||||
for i in range(num_down_blocks):
|
||||
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"},
|
||||
)
|
||||
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.get(
|
||||
f"encoder.down.{i}.downsample.conv.weight"
|
||||
)
|
||||
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.get(
|
||||
f"encoder.down.{i}.downsample.conv.bias"
|
||||
)
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
|
||||
)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
||||
update_vae_attentions_ldm_to_diffusers(
|
||||
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
)
|
||||
|
||||
# Retrieves the keys for the decoder up blocks only
|
||||
num_up_blocks = len(config["up_block_types"])
|
||||
up_blocks = {
|
||||
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
||||
}
|
||||
|
||||
for i in range(num_up_blocks):
|
||||
block_id = num_up_blocks - 1 - i
|
||||
resnets = [
|
||||
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
||||
]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"},
|
||||
)
|
||||
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.weight"
|
||||
]
|
||||
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
||||
f"decoder.up.{block_id}.upsample.conv.bias"
|
||||
]
|
||||
|
||||
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
||||
num_mid_res_blocks = 2
|
||||
for i in range(1, num_mid_res_blocks + 1):
|
||||
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
||||
update_vae_resnet_ldm_to_diffusers(
|
||||
resnets,
|
||||
new_checkpoint,
|
||||
vae_state_dict,
|
||||
mapping={"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"},
|
||||
)
|
||||
|
||||
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
||||
update_vae_attentions_ldm_to_diffusers(
|
||||
mid_attentions, new_checkpoint, vae_state_dict, mapping={"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
||||
)
|
||||
conv_attn_to_linear(new_checkpoint)
|
||||
|
||||
return new_checkpoint
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Image and text input projections
|
||||
"img_in": "x_embedder",
|
||||
"txt_in": "context_embedder",
|
||||
# Timestep and guidance embeddings
|
||||
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
|
||||
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
|
||||
# Modulation parameters
|
||||
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
|
||||
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
|
||||
"single_stream_modulation.lin": "single_stream_modulation.linear",
|
||||
# Final output layer
|
||||
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
|
||||
# Handle fused QKV projections separately as we need to break into Q, K, V projections
|
||||
"img_attn.norm.query_norm": "attn.norm_q",
|
||||
"img_attn.norm.key_norm": "attn.norm_k",
|
||||
"img_attn.proj": "attn.to_out.0",
|
||||
"img_mlp.0": "ff.linear_in",
|
||||
"img_mlp.2": "ff.linear_out",
|
||||
"txt_attn.norm.query_norm": "attn.norm_added_q",
|
||||
"txt_attn.norm.key_norm": "attn.norm_added_k",
|
||||
"txt_attn.proj": "attn.to_add_out",
|
||||
"txt_mlp.0": "ff_context.linear_in",
|
||||
"txt_mlp.2": "ff_context.linear_out",
|
||||
}
|
||||
|
||||
|
||||
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
|
||||
"linear1": "attn.to_qkv_mlp_proj",
|
||||
"linear2": "attn.to_out",
|
||||
"norm.query_norm": "attn.norm_q",
|
||||
"norm.key_norm": "attn.norm_k",
|
||||
}
|
||||
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use
|
||||
# diffusers implementation
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_ada_layer_norm_weights(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight
|
||||
if ".weight" not in key:
|
||||
return
|
||||
|
||||
# If adaLN_modulation is in the key, swap scale and shift parameters
|
||||
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
|
||||
if "adaLN_modulation" in key:
|
||||
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
|
||||
# Assume all such keys are in the AdaLayerNorm key map
|
||||
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
|
||||
new_key = ".".join([new_key_without_param_type, param_type])
|
||||
|
||||
swapped_weight = swap_scale_shift(state_dict.pop(key))
|
||||
state_dict[new_key] = swapped_weight
|
||||
return
|
||||
|
||||
|
||||
def convert_flux2_double_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
new_prefix = "transformer_blocks"
|
||||
if "double_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
if "qkv" in within_block_name:
|
||||
fused_qkv_weight = state_dict.pop(key)
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
if "img" in modality_block_name:
|
||||
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.to_q"
|
||||
new_k_name = "attn.to_k"
|
||||
new_v_name = "attn.to_v"
|
||||
elif "txt" in modality_block_name:
|
||||
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.add_q_proj"
|
||||
new_k_name = "attn.add_k_proj"
|
||||
new_v_name = "attn.add_v_proj"
|
||||
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
|
||||
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
|
||||
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
|
||||
state_dict[new_q_key] = to_q_weight
|
||||
state_dict[new_k_key] = to_k_weight
|
||||
state_dict[new_v_key] = to_v_weight
|
||||
else:
|
||||
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
|
||||
def convert_flux2_single_stream_blocks(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
# Mapping:
|
||||
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
|
||||
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
|
||||
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
|
||||
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
|
||||
new_prefix = "single_transformer_blocks"
|
||||
if "single_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"adaLN_modulation": convert_ada_layer_norm_weights,
|
||||
"double_blocks": convert_flux2_double_stream_blocks,
|
||||
"single_blocks": convert_flux2_single_stream_blocks,
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
if model_type == "test" or model_type == "dummy-flux2":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-flux2",
|
||||
"diffusers_config": {
|
||||
"patch_size": 1,
|
||||
"in_channels": 128,
|
||||
"num_layers": 8,
|
||||
"num_single_layers": 48,
|
||||
"attention_head_dim": 128,
|
||||
"num_attention_heads": 48,
|
||||
"joint_attention_dim": 15360,
|
||||
"timestep_guidance_channels": 256,
|
||||
"mlp_ratio": 3.0,
|
||||
"axes_dims_rope": (32, 32, 32, 32),
|
||||
"rope_theta": 2000,
|
||||
"eps": 1e-6,
|
||||
},
|
||||
}
|
||||
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_flux2_transformer_to_diffusers(original_state_dict: Dict[str, torch.Tensor], model_type: str):
|
||||
config, rename_dict, special_keys_remap = get_flux2_transformer_config(model_type)
|
||||
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = Flux2Transformer2DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.vae:
|
||||
original_vae_ckpt = load_original_checkpoint(args, filename=args.vae_filename)
|
||||
vae = AutoencoderKLFlux2()
|
||||
converted_vae_state_dict = convert_flux2_vae_checkpoint_to_diffusers(original_vae_ckpt, vae.config)
|
||||
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
||||
if not args.full_pipe:
|
||||
vae_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
|
||||
vae.to(vae_dtype).save_pretrained(f"{args.output_path}/vae")
|
||||
|
||||
if args.dit:
|
||||
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
|
||||
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
|
||||
if not args.full_pipe:
|
||||
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
|
||||
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
|
||||
|
||||
if args.full_pipe:
|
||||
tokenizer_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
|
||||
text_encoder_id = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
|
||||
generate_config = GenerationConfig.from_pretrained(text_encoder_id)
|
||||
generate_config.do_sample = True
|
||||
text_encoder = Mistral3ForConditionalGeneration.from_pretrained(
|
||||
text_encoder_id, generation_config=generate_config, torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = AutoProcessor.from_pretrained(tokenizer_id)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
|
||||
)
|
||||
|
||||
pipe = Flux2Pipeline(
|
||||
vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
|
||||
)
|
||||
pipe.save_pretrained(args.output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to convert PRX checkpoint from original codebase to diffusers format.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
|
||||
|
||||
DEFAULT_RESOLUTION = 512
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXBase:
|
||||
context_in_dim: int = 2304
|
||||
hidden_size: int = 1792
|
||||
mlp_ratio: float = 3.5
|
||||
num_heads: int = 28
|
||||
depth: int = 16
|
||||
axes_dim: Tuple[int, int] = (32, 32)
|
||||
theta: int = 10_000
|
||||
time_factor: float = 1000.0
|
||||
time_max_period: int = 10_000
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXFlux(PRXBase):
|
||||
in_channels: int = 16
|
||||
patch_size: int = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXDCAE(PRXBase):
|
||||
in_channels: int = 32
|
||||
patch_size: int = 1
|
||||
|
||||
|
||||
def build_config(vae_type: str) -> Tuple[dict, int]:
|
||||
if vae_type == "flux":
|
||||
cfg = PRXFlux()
|
||||
elif vae_type == "dc-ae":
|
||||
cfg = PRXDCAE()
|
||||
else:
|
||||
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
|
||||
|
||||
config_dict = asdict(cfg)
|
||||
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
|
||||
return config_dict
|
||||
|
||||
|
||||
def create_parameter_mapping(depth: int) -> dict:
|
||||
"""Create mapping from old parameter names to new diffusers names."""
|
||||
|
||||
# Key mappings for structural changes
|
||||
mapping = {}
|
||||
|
||||
# Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
|
||||
for i in range(depth):
|
||||
# QKV projections moved to attention module
|
||||
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
|
||||
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
|
||||
|
||||
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
|
||||
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
|
||||
|
||||
# K norm for text tokens moved to attention module
|
||||
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
|
||||
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
|
||||
|
||||
# Attention output projection
|
||||
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
|
||||
"""Convert old checkpoint parameters to new diffusers format."""
|
||||
|
||||
print("Converting checkpoint parameters...")
|
||||
|
||||
mapping = create_parameter_mapping(depth)
|
||||
converted_state_dict = {}
|
||||
|
||||
for key, value in old_state_dict.items():
|
||||
new_key = key
|
||||
|
||||
# Apply specific mappings if needed
|
||||
if key in mapping:
|
||||
new_key = mapping[key]
|
||||
print(f" Mapped: {key} -> {new_key}")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
print(f"✓ Converted {len(converted_state_dict)} parameters")
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
|
||||
"""Create and load PRXTransformer2DModel from old checkpoint."""
|
||||
|
||||
print(f"Loading checkpoint from: {checkpoint_path}")
|
||||
|
||||
# Load old checkpoint
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
||||
|
||||
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
# Handle different checkpoint formats
|
||||
if isinstance(old_checkpoint, dict):
|
||||
if "model" in old_checkpoint:
|
||||
state_dict = old_checkpoint["model"]
|
||||
elif "state_dict" in old_checkpoint:
|
||||
state_dict = old_checkpoint["state_dict"]
|
||||
else:
|
||||
state_dict = old_checkpoint
|
||||
else:
|
||||
state_dict = old_checkpoint
|
||||
|
||||
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
|
||||
|
||||
# Convert parameter names if needed
|
||||
model_depth = int(config.get("depth", 16))
|
||||
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
|
||||
|
||||
# Create transformer with config
|
||||
print("Creating PRXTransformer2DModel...")
|
||||
transformer = PRXTransformer2DModel(**config)
|
||||
|
||||
# Load state dict
|
||||
print("Loading converted parameters...")
|
||||
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"⚠ Missing keys: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
print(f"⚠ Unexpected keys: {unexpected_keys}")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✓ All parameters loaded successfully!")
|
||||
|
||||
return transformer
|
||||
|
||||
|
||||
def create_scheduler_config(output_path: str, shift: float):
|
||||
"""Create FlowMatchEulerDiscreteScheduler config."""
|
||||
|
||||
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
|
||||
|
||||
scheduler_path = os.path.join(output_path, "scheduler")
|
||||
os.makedirs(scheduler_path, exist_ok=True)
|
||||
|
||||
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
|
||||
json.dump(scheduler_config, f, indent=2)
|
||||
|
||||
print("✓ Created scheduler config")
|
||||
|
||||
|
||||
def download_and_save_vae(vae_type: str, output_path: str):
|
||||
"""Download and save VAE to local directory."""
|
||||
from diffusers import AutoencoderDC, AutoencoderKL
|
||||
|
||||
vae_path = os.path.join(output_path, "vae")
|
||||
os.makedirs(vae_path, exist_ok=True)
|
||||
|
||||
if vae_type == "flux":
|
||||
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
|
||||
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
|
||||
else: # dc-ae
|
||||
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
|
||||
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
|
||||
|
||||
vae.save_pretrained(vae_path)
|
||||
print(f"✓ Saved VAE to {vae_path}")
|
||||
|
||||
|
||||
def download_and_save_text_encoder(output_path: str):
|
||||
"""Download and save T5Gemma text encoder and tokenizer."""
|
||||
from transformers import GemmaTokenizerFast
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
|
||||
|
||||
text_encoder_path = os.path.join(output_path, "text_encoder")
|
||||
tokenizer_path = os.path.join(output_path, "tokenizer")
|
||||
os.makedirs(text_encoder_path, exist_ok=True)
|
||||
os.makedirs(tokenizer_path, exist_ok=True)
|
||||
|
||||
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
|
||||
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
|
||||
# Extract and save only the encoder
|
||||
t5gemma_encoder = t5gemma_model.encoder
|
||||
t5gemma_encoder.save_pretrained(text_encoder_path)
|
||||
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
|
||||
|
||||
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
tokenizer.model_max_length = 256
|
||||
tokenizer.save_pretrained(tokenizer_path)
|
||||
print(f"✓ Saved tokenizer to {tokenizer_path}")
|
||||
|
||||
|
||||
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
|
||||
"""Create model_index.json for the pipeline."""
|
||||
|
||||
if vae_type == "flux":
|
||||
vae_class = "AutoencoderKL"
|
||||
else: # dc-ae
|
||||
vae_class = "AutoencoderDC"
|
||||
|
||||
model_index = {
|
||||
"_class_name": "PRXPipeline",
|
||||
"_diffusers_version": "0.31.0.dev0",
|
||||
"_name_or_path": os.path.basename(output_path),
|
||||
"default_sample_size": default_image_size,
|
||||
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
|
||||
"text_encoder": ["prx", "T5GemmaEncoder"],
|
||||
"tokenizer": ["transformers", "GemmaTokenizerFast"],
|
||||
"transformer": ["diffusers", "PRXTransformer2DModel"],
|
||||
"vae": ["diffusers", vae_class],
|
||||
}
|
||||
|
||||
model_index_path = os.path.join(output_path, "model_index.json")
|
||||
with open(model_index_path, "w") as f:
|
||||
json.dump(model_index, f, indent=2)
|
||||
|
||||
|
||||
def main(args):
|
||||
# Validate inputs
|
||||
if not os.path.exists(args.checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
|
||||
|
||||
config = build_config(args.vae_type)
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
print(f"✓ Output directory: {args.output_path}")
|
||||
|
||||
# Create transformer from checkpoint
|
||||
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
|
||||
|
||||
# Save transformer
|
||||
transformer_path = os.path.join(args.output_path, "transformer")
|
||||
os.makedirs(transformer_path, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
with open(os.path.join(transformer_path, "config.json"), "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save model weights as safetensors
|
||||
state_dict = transformer.state_dict()
|
||||
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
|
||||
print(f"✓ Saved transformer to {transformer_path}")
|
||||
|
||||
# Create scheduler config
|
||||
create_scheduler_config(args.output_path, args.shift)
|
||||
|
||||
download_and_save_vae(args.vae_type, args.output_path)
|
||||
download_and_save_text_encoder(args.output_path)
|
||||
|
||||
# Create model_index.json
|
||||
create_model_index(args.vae_type, args.resolution, args.output_path)
|
||||
|
||||
# Verify the pipeline can be loaded
|
||||
try:
|
||||
pipeline = PRXPipeline.from_pretrained(args.output_path)
|
||||
print("Pipeline loaded successfully!")
|
||||
print(f"Transformer: {type(pipeline.transformer).__name__}")
|
||||
print(f"VAE: {type(pipeline.vae).__name__}")
|
||||
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
|
||||
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
|
||||
|
||||
# Display model info
|
||||
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
|
||||
print(f"✓ Transformer parameters: {num_params:,}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Pipeline verification failed: {e}")
|
||||
return False
|
||||
|
||||
print("Conversion completed successfully!")
|
||||
print(f"Converted pipeline saved to: {args.output_path}")
|
||||
print(f"VAE type: {args.vae_type}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vae_type",
|
||||
type=str,
|
||||
choices=["flux", "dc-ae"],
|
||||
required=True,
|
||||
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
choices=[256, 512, 1024],
|
||||
default=DEFAULT_RESOLUTION,
|
||||
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="Shift for the scheduler",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
success = main(args)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Conversion failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -10,7 +10,7 @@ from accelerate import init_empty_weights
|
||||
from diffusers import (
|
||||
SanaControlNetModel,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
SanaTransformer2DModel,
|
||||
SCMScheduler,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from termcolor import colored
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SanaVideoPipeline,
|
||||
SanaVideoTransformer3DModel,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
|
||||
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
|
||||
|
||||
|
||||
def main(args):
|
||||
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
|
||||
|
||||
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
|
||||
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
|
||||
snapshot_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
file_path = hf_hub_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
else:
|
||||
file_path = args.orig_ckpt_path
|
||||
|
||||
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
|
||||
all_state_dict = torch.load(file_path, weights_only=True)
|
||||
state_dict = all_state_dict.pop("state_dict")
|
||||
converted_state_dict = {}
|
||||
|
||||
# Patch embeddings.
|
||||
converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Caption projection.
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
|
||||
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
# scheduler
|
||||
flow_shift = 8.0
|
||||
if args.task == "i2v":
|
||||
assert args.scheduler_type == "flow-euler", "Scheduler type must be flow-euler for i2v task."
|
||||
|
||||
# model config
|
||||
layer_num = 20
|
||||
# Positional embedding interpolation scale.
|
||||
qk_norm = True
|
||||
|
||||
# sample size
|
||||
if args.video_size == 480:
|
||||
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
|
||||
patch_size = (1, 2, 2)
|
||||
elif args.video_size == 720:
|
||||
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
|
||||
patch_size = (1, 1, 1)
|
||||
else:
|
||||
raise ValueError(f"Video size {args.video_size} is not supported.")
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
f"blocks.{depth}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Linear Attention is all you need 🤘
|
||||
# Self attention.
|
||||
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.k_norm.weight"
|
||||
)
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Feed-forward.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.point_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.t_conv.weight"
|
||||
)
|
||||
|
||||
# Cross-attention.
|
||||
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
|
||||
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
|
||||
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
if qk_norm is not None:
|
||||
# Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.q_norm.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.k_norm.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# Final block.
|
||||
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
|
||||
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer_kwargs = {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 20,
|
||||
"attention_head_dim": 112,
|
||||
"num_layers": 20,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"caption_channels": 2304,
|
||||
"mlp_ratio": 3.0,
|
||||
"attention_bias": False,
|
||||
"sample_size": sample_size,
|
||||
"patch_size": patch_size,
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 1024,
|
||||
}
|
||||
|
||||
transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
|
||||
|
||||
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
|
||||
|
||||
try:
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
state_dict.pop("pos_embed")
|
||||
state_dict.pop("logvar_linear.weight")
|
||||
state_dict.pop("logvar_linear.bias")
|
||||
except KeyError:
|
||||
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
|
||||
|
||||
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
print(f"Total number of transformer parameters: {num_model_params}")
|
||||
|
||||
transformer = transformer.to(weight_dtype)
|
||||
|
||||
if not args.save_full_pipeline:
|
||||
print(
|
||||
colored(
|
||||
f"Only saving transformer model of {args.model_type}. "
|
||||
f"Set --save_full_pipeline to save the whole Pipeline",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
transformer.save_pretrained(
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
else:
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
|
||||
tokenizer.padding_side = "right"
|
||||
text_encoder = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_model_path, torch_dtype=torch.bfloat16
|
||||
).get_decoder()
|
||||
|
||||
# Choose the appropriate pipeline and scheduler based on model type
|
||||
# Original Sana scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
elif args.scheduler_type == "uni-pc":
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
prediction_type="flow_prediction",
|
||||
use_flow_sigmas=True,
|
||||
num_train_timesteps=1000,
|
||||
flow_shift=flow_shift,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
|
||||
pipe = SanaVideoPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video_size",
|
||||
default=480,
|
||||
type=int,
|
||||
choices=[480, 720],
|
||||
required=False,
|
||||
help="Video size of pretrained model, 480 or 720.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default="SanaVideo",
|
||||
type=str,
|
||||
choices=[
|
||||
"SanaVideo",
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type",
|
||||
default="flow-dpm_solver",
|
||||
type=str,
|
||||
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
|
||||
help="Scheduler type to use.",
|
||||
)
|
||||
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
|
||||
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
@@ -7,7 +7,7 @@ from accelerate import init_empty_weights
|
||||
|
||||
from diffusers import AutoencoderKL, SD3Transformer2DModel
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from diffusers import (
|
||||
StableAudioPipeline,
|
||||
StableAudioProjectionModel,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers
|
||||
from diffusers.models import StableCascadeUNet
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.models.model_loading_utils import load_model_dict_into_meta
|
||||
from diffusers.pipelines.wuerstchen import PaellaVQModel
|
||||
from diffusers.utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -6,11 +6,20 @@ import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModel,
|
||||
CLIPVisionModelWithProjection,
|
||||
UMT5EncoderModel,
|
||||
)
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLWan,
|
||||
UniPCMultistepScheduler,
|
||||
WanAnimatePipeline,
|
||||
WanAnimateTransformer3DModel,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanTransformer3DModel,
|
||||
@@ -105,8 +114,203 @@ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"after_proj": "proj_out",
|
||||
}
|
||||
|
||||
ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
|
||||
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
|
||||
"text_embedding.0": "condition_embedder.text_embedder.linear_1",
|
||||
"text_embedding.2": "condition_embedder.text_embedder.linear_2",
|
||||
"time_projection.1": "condition_embedder.time_proj",
|
||||
"head.modulation": "scale_shift_table",
|
||||
"head.head": "proj_out",
|
||||
"modulation": "scale_shift_table",
|
||||
"ffn.0": "ffn.net.0.proj",
|
||||
"ffn.2": "ffn.net.2",
|
||||
# Hack to swap the layer names
|
||||
# The original model calls the norms in following order: norm1, norm3, norm2
|
||||
# We convert it to: norm1, norm2, norm3
|
||||
"norm2": "norm__placeholder",
|
||||
"norm3": "norm2",
|
||||
"norm__placeholder": "norm3",
|
||||
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
|
||||
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
# Add attention component mappings
|
||||
"self_attn.q": "attn1.to_q",
|
||||
"self_attn.k": "attn1.to_k",
|
||||
"self_attn.v": "attn1.to_v",
|
||||
"self_attn.o": "attn1.to_out.0",
|
||||
"self_attn.norm_q": "attn1.norm_q",
|
||||
"self_attn.norm_k": "attn1.norm_k",
|
||||
"cross_attn.q": "attn2.to_q",
|
||||
"cross_attn.k": "attn2.to_k",
|
||||
"cross_attn.v": "attn2.to_v",
|
||||
"cross_attn.o": "attn2.to_out.0",
|
||||
"cross_attn.norm_q": "attn2.norm_q",
|
||||
"cross_attn.norm_k": "attn2.norm_k",
|
||||
"cross_attn.k_img": "attn2.to_k_img",
|
||||
"cross_attn.v_img": "attn2.to_v_img",
|
||||
"cross_attn.norm_k_img": "attn2.norm_k_img",
|
||||
# After cross_attn -> attn2 rename, we need to rename the img keys
|
||||
"attn2.to_k_img": "attn2.add_k_proj",
|
||||
"attn2.to_v_img": "attn2.add_v_proj",
|
||||
"attn2.norm_k_img": "attn2.norm_added_k",
|
||||
# Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
|
||||
# Motion encoder mappings
|
||||
# The name mapping is complicated for the convolutional part so we handle that in its own function
|
||||
"motion_encoder.enc.fc": "motion_encoder.motion_network",
|
||||
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
|
||||
# Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
|
||||
"face_encoder.conv1_local.conv": "face_encoder.conv1_local",
|
||||
"face_encoder.conv2.conv": "face_encoder.conv2",
|
||||
"face_encoder.conv3.conv": "face_encoder.conv3",
|
||||
# Face adapter mappings are handled in a separate function
|
||||
}
|
||||
|
||||
|
||||
# TODO: Verify this and simplify if possible.
|
||||
def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
|
||||
"""
|
||||
Convert all motion encoder weights for Animate model.
|
||||
|
||||
In the original model:
|
||||
- All Linear layers in fc use EqualLinear
|
||||
- All Conv2d layers in convs use EqualConv2d (except blur_conv which is initialized separately)
|
||||
- Blur kernels are stored as buffers in Sequential modules
|
||||
- ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
|
||||
|
||||
Conversion strategy:
|
||||
1. Drop .kernel buffers (blur kernels)
|
||||
2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
|
||||
"""
|
||||
# Skip if not a weight, bias, or kernel
|
||||
if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
|
||||
return
|
||||
|
||||
# Handle Blur kernel buffers from original implementation.
|
||||
# After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
|
||||
# Diffusers constructs blur kernels as a non-persistent buffer so we must drop these keys
|
||||
if ".kernel" in key and "motion_encoder" in key:
|
||||
# Remove unexpected blur kernel buffers to avoid strict load errors
|
||||
state_dict.pop(key, None)
|
||||
return
|
||||
|
||||
# Rename Sequential indices to named components in ConvLayer and ResBlock
|
||||
if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
|
||||
parts = key.split(".")
|
||||
|
||||
# Find the sequential index (digit) after convs or after conv1/conv2/skip
|
||||
# Examples:
|
||||
# - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
|
||||
# - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
|
||||
# - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
|
||||
# - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
|
||||
# - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
|
||||
# - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
|
||||
# - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
|
||||
# - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
|
||||
# - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
|
||||
# - enc.net_app.convs.8 -> conv_out (final conv layer)
|
||||
|
||||
convs_idx = parts.index("convs") if "convs" in parts else -1
|
||||
if convs_idx >= 0 and len(parts) - convs_idx >= 2:
|
||||
bias = False
|
||||
# The nn.Sequential index will always follow convs
|
||||
sequential_idx = int(parts[convs_idx + 1])
|
||||
if sequential_idx == 0:
|
||||
if key.endswith(".weight"):
|
||||
new_key = "motion_encoder.conv_in.weight"
|
||||
elif key.endswith(".bias"):
|
||||
new_key = "motion_encoder.conv_in.act_fn.bias"
|
||||
bias = True
|
||||
elif sequential_idx == final_conv_idx:
|
||||
if key.endswith(".weight"):
|
||||
new_key = "motion_encoder.conv_out.weight"
|
||||
else:
|
||||
# Intermediate .convs. layers, which get mapped to .res_blocks.
|
||||
prefix = "motion_encoder.res_blocks."
|
||||
|
||||
layer_name = parts[convs_idx + 2]
|
||||
if layer_name == "skip":
|
||||
layer_name = "conv_skip"
|
||||
|
||||
if key.endswith(".weight"):
|
||||
param_name = "weight"
|
||||
elif key.endswith(".bias"):
|
||||
param_name = "act_fn.bias"
|
||||
bias = True
|
||||
|
||||
suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
|
||||
suffix = ".".join(suffix_parts)
|
||||
new_key = prefix + suffix
|
||||
|
||||
param = state_dict.pop(key)
|
||||
if bias:
|
||||
param = param.squeeze()
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert face adapter weights for the Animate model.
|
||||
|
||||
The original model uses a fused KV projection but the diffusers models uses separate K and V projections.
|
||||
"""
|
||||
# Skip if not a weight or bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
prefix = "face_adapter."
|
||||
if ".fuser_blocks." in key:
|
||||
parts = key.split(".")
|
||||
|
||||
module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
|
||||
if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
|
||||
block_idx = parts[module_list_idx + 1]
|
||||
layer_name = parts[module_list_idx + 2]
|
||||
param_name = parts[module_list_idx + 3]
|
||||
|
||||
if layer_name == "linear1_kv":
|
||||
layer_name_k = "to_k"
|
||||
layer_name_v = "to_v"
|
||||
|
||||
suffix_k = ".".join([block_idx, layer_name_k, param_name])
|
||||
suffix_v = ".".join([block_idx, layer_name_v, param_name])
|
||||
new_key_k = prefix + suffix_k
|
||||
new_key_v = prefix + suffix_v
|
||||
|
||||
kv_proj = state_dict.pop(key)
|
||||
k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
|
||||
state_dict[new_key_k] = k_proj
|
||||
state_dict[new_key_v] = v_proj
|
||||
return
|
||||
else:
|
||||
if layer_name == "q_norm":
|
||||
new_layer_name = "norm_q"
|
||||
elif layer_name == "k_norm":
|
||||
new_layer_name = "norm_k"
|
||||
elif layer_name == "linear1_q":
|
||||
new_layer_name = "to_q"
|
||||
elif layer_name == "linear2":
|
||||
new_layer_name = "to_out"
|
||||
|
||||
suffix_parts = [block_idx, new_layer_name, param_name]
|
||||
suffix = ".".join(suffix_parts)
|
||||
new_key = prefix + suffix
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
return
|
||||
return
|
||||
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"motion_encoder": convert_animate_motion_encoder_weights,
|
||||
"face_adapter": convert_animate_face_adapter_weights,
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
@@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
|
||||
}
|
||||
RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif model_type == "Wan2.2-Animate-14B":
|
||||
config = {
|
||||
"model_id": "Wan-AI/Wan2.2-Animate-14B",
|
||||
"diffusers_config": {
|
||||
"image_dim": 1280,
|
||||
"added_kv_proj_dim": 5120,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-06,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"in_channels": 36,
|
||||
"num_attention_heads": 40,
|
||||
"num_layers": 40,
|
||||
"out_channels": 16,
|
||||
"patch_size": (1, 2, 2),
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"text_dim": 4096,
|
||||
"rope_max_seq_len": 1024,
|
||||
"pos_embed_seq_len": None,
|
||||
"motion_encoder_size": 512, # Start of Wan Animate-specific configs
|
||||
"motion_style_dim": 512,
|
||||
"motion_dim": 20,
|
||||
"motion_encoder_dim": 512,
|
||||
"face_encoder_hidden_dim": 1024,
|
||||
"face_encoder_num_heads": 4,
|
||||
"inject_face_latents_blocks": 5,
|
||||
},
|
||||
}
|
||||
RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
|
||||
SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
|
||||
|
||||
|
||||
@@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None):
|
||||
original_state_dict = load_sharded_safetensors(model_dir)
|
||||
|
||||
with init_empty_weights():
|
||||
if "VACE" not in model_type:
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
else:
|
||||
if "Animate" in model_type:
|
||||
transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
|
||||
elif "VACE" in model_type:
|
||||
transformer = WanVACETransformer3DModel.from_config(diffusers_config)
|
||||
else:
|
||||
transformer = WanTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None):
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
# Load state dict into the meta model, which will materialize the tensors
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
|
||||
# Move to CPU to ensure all tensors are materialized
|
||||
transformer = transformer.to("cpu")
|
||||
|
||||
return transformer
|
||||
|
||||
|
||||
@@ -926,7 +1168,7 @@ DTYPE_MAPPING = {
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
|
||||
if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
|
||||
transformer = convert_transformer(args.model_type, stage="high_noise_model")
|
||||
transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
|
||||
else:
|
||||
@@ -942,7 +1184,7 @@ if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
|
||||
if "FLF2V" in args.model_type:
|
||||
flow_shift = 16.0
|
||||
elif "TI2V" in args.model_type:
|
||||
elif "TI2V" in args.model_type or "Animate" in args.model_type:
|
||||
flow_shift = 5.0
|
||||
else:
|
||||
flow_shift = 3.0
|
||||
@@ -954,6 +1196,8 @@ if __name__ == "__main__":
|
||||
if args.dtype != "none":
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
transformer.to(dtype)
|
||||
if transformer_2 is not None:
|
||||
transformer_2.to(dtype)
|
||||
|
||||
if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
|
||||
pipe = WanImageToVideoPipeline(
|
||||
@@ -1016,6 +1260,21 @@ if __name__ == "__main__":
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
elif "Animate" in args.model_type:
|
||||
image_encoder = CLIPVisionModel.from_pretrained(
|
||||
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
|
||||
)
|
||||
image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
||||
|
||||
pipe = WanAnimatePipeline(
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
else:
|
||||
pipe = WanPipeline(
|
||||
transformer=transformer,
|
||||
|
||||
@@ -149,7 +149,9 @@ else:
|
||||
_import_structure["guiders"].extend(
|
||||
[
|
||||
"AdaptiveProjectedGuidance",
|
||||
"AdaptiveProjectedMixGuidance",
|
||||
"AutoGuidance",
|
||||
"BaseGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"FrequencyDecoupledGuidance",
|
||||
@@ -184,6 +186,9 @@ else:
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLCosmos",
|
||||
"AutoencoderKLFlux2",
|
||||
"AutoencoderKLHunyuanImage",
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
@@ -194,9 +199,11 @@ else:
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"BriaFiboTransformer2DModel",
|
||||
"BriaTransformer2DModel",
|
||||
"CacheMixin",
|
||||
"ChromaTransformer2DModel",
|
||||
"ChronoEditTransformer3DModel",
|
||||
"CogVideoXTransformer3DModel",
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"CogView4Transformer2DModel",
|
||||
@@ -209,6 +216,7 @@ else:
|
||||
"CosmosTransformer3DModel",
|
||||
"DiTTransformer2DModel",
|
||||
"EasyAnimateTransformer3DModel",
|
||||
"Flux2Transformer2DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
@@ -216,6 +224,7 @@ else:
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"HunyuanImageTransformer2DModel",
|
||||
"HunyuanVideoFramepackTransformer3DModel",
|
||||
"HunyuanVideoTransformer3DModel",
|
||||
"I2VGenXLUNet",
|
||||
@@ -234,11 +243,13 @@ else:
|
||||
"ParallelConfig",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"PRXTransformer2DModel",
|
||||
"QwenImageControlNetModel",
|
||||
"QwenImageMultiControlNetModel",
|
||||
"QwenImageTransformer2DModel",
|
||||
"SanaControlNetModel",
|
||||
"SanaTransformer2DModel",
|
||||
"SanaVideoTransformer3DModel",
|
||||
"SD3ControlNetModel",
|
||||
"SD3MultiControlNetModel",
|
||||
"SD3Transformer2DModel",
|
||||
@@ -259,8 +270,10 @@ else:
|
||||
"UNetSpatioTemporalConditionModel",
|
||||
"UVit2DModel",
|
||||
"VQModel",
|
||||
"WanAnimateTransformer3DModel",
|
||||
"WanTransformer3DModel",
|
||||
"WanVACETransformer3DModel",
|
||||
"ZImageTransformer2DModel",
|
||||
"attention_backend",
|
||||
]
|
||||
)
|
||||
@@ -398,6 +411,7 @@ else:
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanModularPipeline",
|
||||
]
|
||||
@@ -424,9 +438,11 @@ else:
|
||||
"AuraFlowPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"BriaFiboPipeline",
|
||||
"BriaPipeline",
|
||||
"ChromaImg2ImgPipeline",
|
||||
"ChromaPipeline",
|
||||
"ChronoEditPipeline",
|
||||
"CLIPImageProjection",
|
||||
"CogVideoXFunControlPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
@@ -444,6 +460,7 @@ else:
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"Flux2Pipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
@@ -461,6 +478,8 @@ else:
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"HunyuanImagePipeline",
|
||||
"HunyuanImageRefinerPipeline",
|
||||
"HunyuanSkyreelsImageToVideoPipeline",
|
||||
"HunyuanVideoFramepackPipeline",
|
||||
"HunyuanVideoImageToVideoPipeline",
|
||||
@@ -519,6 +538,7 @@ else:
|
||||
"PixArtAlphaPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"PRXPipeline",
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
@@ -529,10 +549,13 @@ else:
|
||||
"QwenImagePipeline",
|
||||
"ReduxImageEncoder",
|
||||
"SanaControlNetPipeline",
|
||||
"SanaImageToVideoPipeline",
|
||||
"SanaPAGPipeline",
|
||||
"SanaPipeline",
|
||||
"SanaSprintImg2ImgPipeline",
|
||||
"SanaSprintPipeline",
|
||||
"SanaVideoPipeline",
|
||||
"SanaVideoPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -620,6 +643,7 @@ else:
|
||||
"VisualClozeGenerationPipeline",
|
||||
"VisualClozePipeline",
|
||||
"VQDiffusionPipeline",
|
||||
"WanAnimatePipeline",
|
||||
"WanImageToVideoPipeline",
|
||||
"WanPipeline",
|
||||
"WanVACEPipeline",
|
||||
@@ -627,6 +651,7 @@ else:
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -847,7 +872,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .guiders import (
|
||||
AdaptiveProjectedGuidance,
|
||||
AdaptiveProjectedMixGuidance,
|
||||
AutoGuidance,
|
||||
BaseGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
@@ -878,6 +905,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLFlux2,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -888,9 +918,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderOobleck,
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
BriaTransformer2DModel,
|
||||
CacheMixin,
|
||||
ChromaTransformer2DModel,
|
||||
ChronoEditTransformer3DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
@@ -903,6 +935,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CosmosTransformer3DModel,
|
||||
DiTTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
@@ -910,6 +943,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
I2VGenXLUNet,
|
||||
@@ -928,11 +962,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ParallelConfig,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
PRXTransformer2DModel,
|
||||
QwenImageControlNetModel,
|
||||
QwenImageMultiControlNetModel,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaControlNetModel,
|
||||
SanaTransformer2DModel,
|
||||
SanaVideoTransformer3DModel,
|
||||
SD3ControlNetModel,
|
||||
SD3MultiControlNetModel,
|
||||
SD3Transformer2DModel,
|
||||
@@ -952,6 +988,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
UNetSpatioTemporalConditionModel,
|
||||
UVit2DModel,
|
||||
VQModel,
|
||||
WanAnimateTransformer3DModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
attention_backend,
|
||||
@@ -1066,6 +1103,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
WanModularPipeline,
|
||||
)
|
||||
@@ -1088,9 +1126,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
BriaFiboPipeline,
|
||||
BriaPipeline,
|
||||
ChromaImg2ImgPipeline,
|
||||
ChromaPipeline,
|
||||
ChronoEditPipeline,
|
||||
CLIPImageProjection,
|
||||
CogVideoXFunControlPipeline,
|
||||
CogVideoXImageToVideoPipeline,
|
||||
@@ -1108,6 +1148,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
Flux2Pipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
@@ -1125,6 +1166,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
HunyuanImagePipeline,
|
||||
HunyuanImageRefinerPipeline,
|
||||
HunyuanSkyreelsImageToVideoPipeline,
|
||||
HunyuanVideoFramepackPipeline,
|
||||
HunyuanVideoImageToVideoPipeline,
|
||||
@@ -1183,6 +1226,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtAlphaPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
PRXPipeline,
|
||||
QwenImageControlNetInpaintPipeline,
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
@@ -1193,10 +1237,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImagePipeline,
|
||||
ReduxImageEncoder,
|
||||
SanaControlNetPipeline,
|
||||
SanaImageToVideoPipeline,
|
||||
SanaPAGPipeline,
|
||||
SanaPipeline,
|
||||
SanaSprintImg2ImgPipeline,
|
||||
SanaSprintPipeline,
|
||||
SanaVideoPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
@@ -1283,6 +1329,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VisualClozeGenerationPipeline,
|
||||
VisualClozePipeline,
|
||||
VQDiffusionPipeline,
|
||||
WanAnimatePipeline,
|
||||
WanImageToVideoPipeline,
|
||||
WanPipeline,
|
||||
WanVACEPipeline,
|
||||
@@ -1290,6 +1337,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -14,28 +14,18 @@
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ..utils import is_torch_available
|
||||
from ..utils import is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
||||
from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||
from .guider_utils import BaseGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
|
||||
|
||||
GuiderType = Union[
|
||||
AdaptiveProjectedGuidance,
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
FrequencyDecoupledGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
TangentialClassifierFreeGuidance,
|
||||
]
|
||||
|
||||
@@ -65,8 +65,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
@@ -76,19 +77,27 @@ class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
@@ -152,6 +161,44 @@ class MomentumBuffer:
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
|
||||
"""
|
||||
if isinstance(self.running_average, torch.Tensor):
|
||||
shape = tuple(self.running_average.shape)
|
||||
|
||||
# Calculate statistics
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": self.running_average.mean().item(),
|
||||
"std": self.running_average.std().item(),
|
||||
"min": self.running_average.min().item(),
|
||||
"max": self.running_average.max().item(),
|
||||
}
|
||||
|
||||
# Get a slice (max 3 elements per dimension)
|
||||
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
|
||||
sliced_data = self.running_average[slice_indices]
|
||||
|
||||
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
|
||||
slice_str = str(sliced_data.detach().float().cpu().numpy())
|
||||
if len(slice_str) > 200: # Truncate if too long
|
||||
slice_str = slice_str[:200] + "..."
|
||||
|
||||
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
|
||||
|
||||
return (
|
||||
f"MomentumBuffer(\n"
|
||||
f" momentum={self.momentum},\n"
|
||||
f" shape={shape},\n"
|
||||
f" stats=[{stats_str}],\n"
|
||||
f" slice={slice_str}\n"
|
||||
f")"
|
||||
)
|
||||
else:
|
||||
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
|
||||
@@ -0,0 +1,297 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class AdaptiveProjectedMixGuidance(BaseGuidance):
|
||||
"""
|
||||
Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
|
||||
(CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
|
||||
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
||||
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
||||
The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
|
||||
improve image quality and fix
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
|
||||
image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
|
||||
Steps are Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which the classifier-free guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which the classifier-free guidance stops.
|
||||
adaptive_projected_guidance_start_step (`int`, defaults to `5`):
|
||||
The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
|
||||
used, and momentum buffer is updated).
|
||||
enabled (`bool`, defaults to `True`):
|
||||
Whether this guidance is enabled.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 3.5,
|
||||
guidance_rescale: float = 0.0,
|
||||
adaptive_projected_guidance_scale: float = 10.0,
|
||||
adaptive_projected_guidance_momentum: float = -0.5,
|
||||
adaptive_projected_guidance_rescale: float = 10.0,
|
||||
eta: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
adaptive_projected_guidance_start_step: int = 5,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.eta = eta
|
||||
self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
# no guidance
|
||||
if not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
|
||||
# CFG + update momentum buffer
|
||||
elif not self._is_apg_enabled():
|
||||
if self.momentum_buffer is not None:
|
||||
update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
|
||||
# CFG + update momentum buffer
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
# APG
|
||||
elif self._is_apg_enabled():
|
||||
pred = normalized_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.adaptive_projected_guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_apg_enabled() or self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
# Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_apg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
if not self._is_cfg_enabled():
|
||||
return False
|
||||
|
||||
is_within_range = False
|
||||
if self._step is not None:
|
||||
is_within_range = self._step > self.adaptive_projected_guidance_start_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def get_state(self):
|
||||
state = super().get_state()
|
||||
state["momentum_buffer"] = self.momentum_buffer
|
||||
state["is_apg_enabled"] = self._is_apg_enabled()
|
||||
state["is_cfg_enabled"] = self._is_cfg_enabled()
|
||||
return state
|
||||
|
||||
|
||||
# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
|
||||
"""
|
||||
if isinstance(self.running_average, torch.Tensor):
|
||||
shape = tuple(self.running_average.shape)
|
||||
|
||||
# Calculate statistics
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": self.running_average.mean().item(),
|
||||
"std": self.running_average.std().item(),
|
||||
"min": self.running_average.min().item(),
|
||||
"max": self.running_average.max().item(),
|
||||
}
|
||||
|
||||
# Get a slice (max 3 elements per dimension)
|
||||
slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
|
||||
sliced_data = self.running_average[slice_indices]
|
||||
|
||||
# Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
|
||||
slice_str = str(sliced_data.detach().float().cpu().numpy())
|
||||
if len(slice_str) > 200: # Truncate if too long
|
||||
slice_str = slice_str[:200] + "..."
|
||||
|
||||
stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
|
||||
|
||||
return (
|
||||
f"MomentumBuffer(\n"
|
||||
f" momentum={self.momentum},\n"
|
||||
f" shape={shape},\n"
|
||||
f" stats=[{stats_str}],\n"
|
||||
f" slice={slice_str}\n"
|
||||
f")"
|
||||
)
|
||||
else:
|
||||
return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
|
||||
|
||||
|
||||
def update_momentum_buffer(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
if momentum_buffer is not None:
|
||||
update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
|
||||
diff = momentum_buffer.running_average
|
||||
else:
|
||||
diff = pred_cond - pred_uncond
|
||||
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + guidance_scale * normalized_update
|
||||
|
||||
return pred
|
||||
@@ -72,8 +72,9 @@ class AutoGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.auto_guidance_layers = auto_guidance_layers
|
||||
@@ -132,16 +133,21 @@ class AutoGuidance(BaseGuidance):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
registry.remove_hook(name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -27,43 +27,50 @@ if TYPE_CHECKING:
|
||||
|
||||
class ClassifierFreeGuidance(BaseGuidance):
|
||||
"""
|
||||
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
|
||||
Implements Classifier-Free Guidance (CFG) for diffusion models.
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
|
||||
proposes scaling and shifting the conditional distribution based on the difference between conditional and
|
||||
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
Reference: https://huggingface.co/papers/2207.12598
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
CFG improves generation quality and prompt adherence by jointly training models on both conditional and
|
||||
unconditional data, then combining predictions during inference. This allows trading off between quality (high
|
||||
guidance) and diversity (low guidance).
|
||||
|
||||
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
|
||||
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
|
||||
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
|
||||
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
|
||||
**Two CFG Formulations:**
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
1. **Original formulation** (from paper):
|
||||
```
|
||||
x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
|
||||
```
|
||||
Moves conditional predictions further from unconditional ones.
|
||||
|
||||
2. **Diffusers-native formulation** (default, from Imagen paper):
|
||||
```
|
||||
x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
|
||||
```
|
||||
Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
|
||||
quality", "watermarks"). Equivalent in theory but more intuitive.
|
||||
|
||||
Use `use_original_formulation=True` to switch to the original formulation.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
|
||||
may reduce quality. Typical range: 1.0-20.0.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
|
||||
Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
|
||||
to 1.0 (full rescaling).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
|
||||
diffusers-native formulation from the Imagen paper.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
|
||||
steps.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
|
||||
steps.
|
||||
enabled (`bool`, defaults to `True`):
|
||||
Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
@@ -76,23 +83,29 @@ class ClassifierFreeGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -68,31 +68,41 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.zero_init_steps = zero_init_steps
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
if self._step < self.zero_init_steps:
|
||||
# YiYi Notes: add default behavior for self._enabled == False
|
||||
if not self._enabled:
|
||||
pred = pred_cond
|
||||
|
||||
elif self._step < self.zero_init_steps:
|
||||
pred = torch.zeros_like(pred_cond)
|
||||
elif not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
|
||||
@@ -149,6 +149,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
stop: Union[float, List[float], Tuple[float]] = 1.0,
|
||||
guidance_rescale_space: str = "data",
|
||||
upcast_to_double: bool = True,
|
||||
enabled: bool = True,
|
||||
):
|
||||
if not _CAN_USE_KORNIA:
|
||||
raise ImportError(
|
||||
@@ -160,7 +161,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
# Set start to earliest start for any freq component and stop to latest stop for any freq component
|
||||
min_start = start if isinstance(start, float) else min(start)
|
||||
max_stop = stop if isinstance(stop, float) else max(stop)
|
||||
super().__init__(min_start, max_stop)
|
||||
super().__init__(min_start, max_stop, enabled)
|
||||
|
||||
self.guidance_scales = guidance_scales
|
||||
self.levels = len(guidance_scales)
|
||||
@@ -217,16 +218,21 @@ class FrequencyDecoupledGuidance(BaseGuidance):
|
||||
f"({len(self.guidance_scales)})"
|
||||
)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -40,7 +40,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
_input_predictions = None
|
||||
_identifier_key = "__guidance_identifier__"
|
||||
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0):
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
|
||||
logger.warning(
|
||||
"Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
|
||||
)
|
||||
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._step: int = None
|
||||
@@ -48,7 +52,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
||||
self._enabled = True
|
||||
self._enabled = enabled
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
||||
@@ -60,6 +64,31 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def new(self, **kwargs):
|
||||
"""
|
||||
Creates a copy of this guider instance, optionally with modified configuration parameters.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
|
||||
returns an exact copy with the same configuration.
|
||||
|
||||
Returns:
|
||||
A new guider instance with the same (or updated) configuration.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Create a CFG guider
|
||||
guider = ClassifierFreeGuidance(guidance_scale=3.5)
|
||||
|
||||
# Create an exact copy
|
||||
same_guider = guider.new()
|
||||
|
||||
# Create a copy with different start step, keeping other config the same
|
||||
new_guider = guider.new(guidance_scale=5)
|
||||
```
|
||||
"""
|
||||
return self.__class__.from_config(self.config, **kwargs)
|
||||
|
||||
def disable(self):
|
||||
self._enabled = False
|
||||
|
||||
@@ -72,42 +101,52 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep = timestep
|
||||
self._count_prepared = 0
|
||||
|
||||
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
|
||||
def get_state(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
|
||||
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
|
||||
the values of the provided keyword arguments to this method.
|
||||
|
||||
Args:
|
||||
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation.
|
||||
|
||||
If a string is provided, it will be used as the conditional data (or unconditional if used with a
|
||||
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
|
||||
conditional data identifier and the second element must be the unconditional data identifier or None.
|
||||
|
||||
Example:
|
||||
```
|
||||
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
|
||||
|
||||
BaseGuidance.set_input_fields(
|
||||
latents="latents",
|
||||
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
||||
)
|
||||
```
|
||||
Returns the current state of the guidance technique as a dictionary. The state variables will be included in
|
||||
the __repr__ method. Returns:
|
||||
`Dict[str, Any]`: A dictionary containing the current state variables including:
|
||||
- step: Current inference step
|
||||
- num_inference_steps: Total number of inference steps
|
||||
- timestep: Current timestep tensor
|
||||
- count_prepared: Number of times prepare_models has been called
|
||||
- enabled: Whether the guidance is enabled
|
||||
- num_conditions: Number of conditions
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
is_string = isinstance(value, str)
|
||||
is_tuple_of_str_with_len_2 = (
|
||||
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
|
||||
)
|
||||
if not (is_string or is_tuple_of_str_with_len_2):
|
||||
raise ValueError(
|
||||
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
||||
)
|
||||
self._input_fields = kwargs
|
||||
state = {
|
||||
"step": self._step,
|
||||
"num_inference_steps": self._num_inference_steps,
|
||||
"timestep": self._timestep,
|
||||
"count_prepared": self._count_prepared,
|
||||
"enabled": self._enabled,
|
||||
"num_conditions": self.num_conditions,
|
||||
}
|
||||
return state
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""
|
||||
Returns a string representation of the guidance object including both config and current state.
|
||||
"""
|
||||
# Get ConfigMixin's __repr__
|
||||
str_repr = super().__repr__()
|
||||
|
||||
# Get current state
|
||||
state = self.get_state()
|
||||
|
||||
# Format each state variable on its own line with indentation
|
||||
state_lines = []
|
||||
for k, v in state.items():
|
||||
# Convert value to string and handle multi-line values
|
||||
v_str = str(v)
|
||||
if "\n" in v_str:
|
||||
# For multi-line values (like MomentumBuffer), indent subsequent lines
|
||||
v_lines = v_str.split("\n")
|
||||
v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
|
||||
state_lines.append(f" {k}: {v_str}")
|
||||
|
||||
state_str = "\n".join(state_lines)
|
||||
|
||||
return f"{str_repr}\nState:\n{state_str}"
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
@@ -127,6 +166,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
@@ -154,6 +198,49 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
|
||||
@classmethod
|
||||
def _prepare_batch(
|
||||
cls,
|
||||
data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
"""
|
||||
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
||||
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
||||
|
||||
Args:
|
||||
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
||||
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
||||
length 2 is provided, the first element must be the conditional data identifier and the second element
|
||||
must be the unconditional data identifier or None.
|
||||
data (`BlockState`):
|
||||
The input data to be prepared.
|
||||
tuple_index (`int`):
|
||||
The index to use when accessing input fields that are tuples.
|
||||
|
||||
Returns:
|
||||
`BlockState`: The prepared batch of data.
|
||||
"""
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
data_batch = {}
|
||||
for key, value in data.items():
|
||||
try:
|
||||
if isinstance(value, torch.Tensor):
|
||||
data_batch[key] = value
|
||||
elif isinstance(value, tuple):
|
||||
data_batch[key] = value[tuple_index]
|
||||
else:
|
||||
raise ValueError(f"Invalid value type: {type(value)}")
|
||||
except ValueError:
|
||||
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@classmethod
|
||||
def _prepare_batch_from_block_state(
|
||||
cls,
|
||||
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
||||
data: "BlockState",
|
||||
@@ -182,10 +269,6 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
if input_fields is None:
|
||||
raise ValueError(
|
||||
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
|
||||
)
|
||||
data_batch = {}
|
||||
for key, value in input_fields.items():
|
||||
try:
|
||||
@@ -290,7 +373,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
|
||||
@@ -98,8 +98,9 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = perturbed_guidance_scale
|
||||
@@ -168,12 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -186,8 +182,28 @@ class PerturbedAttentionGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -100,8 +100,9 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
@@ -164,12 +165,7 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -182,8 +178,28 @@ class SkipLayerGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -92,8 +92,9 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.seg_guidance_scale = seg_guidance_scale
|
||||
@@ -153,12 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
for hook_name in self._seg_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
@@ -171,8 +167,28 @@ class SmoothedEnergyGuidance(BaseGuidance):
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -58,23 +58,29 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
|
||||
@@ -108,8 +108,10 @@ def _register_attention_processors_metadata():
|
||||
from ..models.attention_processor import AttnProcessor2_0
|
||||
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
|
||||
from ..models.transformers.transformer_flux import FluxAttnProcessor
|
||||
from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
|
||||
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
|
||||
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
|
||||
from ..models.transformers.transformer_z_image import ZSingleStreamAttnProcessor
|
||||
|
||||
# AttnProcessor2_0
|
||||
AttentionProcessorRegistry.register(
|
||||
@@ -149,6 +151,22 @@ def _register_attention_processors_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanImageAttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=HunyuanImageAttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
# ZSingleStreamAttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=ZSingleStreamAttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
@@ -162,10 +180,15 @@ def _register_transformer_blocks_metadata():
|
||||
HunyuanVideoTokenReplaceTransformerBlock,
|
||||
HunyuanVideoTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_hunyuanimage import (
|
||||
HunyuanImageSingleTransformerBlock,
|
||||
HunyuanImageTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
from ..models.transformers.transformer_z_image import ZImageTransformerBlock
|
||||
|
||||
# BasicTransformerBlock
|
||||
TransformerBlockRegistry.register(
|
||||
@@ -283,6 +306,31 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanImage2.1
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanImageTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanImageSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# ZImage
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=ZImageTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
@@ -308,4 +356,6 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
|
||||
# not sure what this is yet.
|
||||
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_ZSingleStreamAttnProcessor = _skip_attention___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
@@ -203,10 +203,12 @@ class ContextParallelSplitHook(ModelHook):
|
||||
|
||||
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
|
||||
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
|
||||
raise ValueError(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
|
||||
logger.warning_once(
|
||||
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
|
||||
)
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
return x
|
||||
else:
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
|
||||
class ContextParallelGatherHook(ModelHook):
|
||||
|
||||
@@ -409,7 +409,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
src_w = width if ratio < src_ratio else image.width * height // image.height
|
||||
src_h = height if ratio >= src_ratio else image.height * width // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
|
||||
@@ -460,7 +460,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
src_w = width if ratio > src_ratio else image.width * height // image.height
|
||||
src_h = height if ratio <= src_ratio else image.height * width // image.width
|
||||
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
||||
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION[self.config.resample])
|
||||
res = Image.new("RGB", (width, height))
|
||||
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
||||
return res
|
||||
@@ -1045,16 +1045,39 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
||||
r"""
|
||||
Convert an RGB-like depth image to a depth map.
|
||||
|
||||
Args:
|
||||
image (`Union[np.ndarray, torch.Tensor]`):
|
||||
The RGB-like depth image to convert.
|
||||
|
||||
Returns:
|
||||
`Union[np.ndarray, torch.Tensor]`:
|
||||
The corresponding depth map.
|
||||
"""
|
||||
return image[:, :, 1] * 2**8 + image[:, :, 2]
|
||||
# 1. Cast the tensor to a larger integer type (e.g., int32)
|
||||
# to safely perform the multiplication by 256.
|
||||
# 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
|
||||
# 3. Cast the final result to the desired depth map type (uint16) if needed
|
||||
# before returning, though leaving it as int32/int64 is often safer
|
||||
# for return value from a library function.
|
||||
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Cast to a safe dtype (e.g., int32 or int64) for the calculation
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.to(torch.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# You may want to cast the final result to uint16, but casting to a
|
||||
# larger int type (like int32) is sufficient to fix the overflow.
|
||||
# depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.to(original_dtype)
|
||||
|
||||
elif isinstance(image, np.ndarray):
|
||||
# NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
|
||||
original_dtype = image.dtype
|
||||
image_safe = image.astype(np.int32)
|
||||
|
||||
# Calculate the depth map
|
||||
depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
|
||||
|
||||
# depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
|
||||
return depth_map.astype(original_dtype)
|
||||
else:
|
||||
raise TypeError("Input image must be a torch.Tensor or np.ndarray")
|
||||
|
||||
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
|
||||
r"""
|
||||
|
||||
@@ -81,6 +81,7 @@ if is_torch_available():
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
"SkyReelsV2LoraLoaderMixin",
|
||||
"QwenImageLoraLoaderMixin",
|
||||
"Flux2LoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
@@ -113,6 +114,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AuraFlowLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
Flux2LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
|
||||
@@ -1977,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
"time_projection.1.diff_b"
|
||||
)
|
||||
|
||||
if any("head.head" in k for k in state_dict):
|
||||
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
|
||||
f"head.head.{lora_down_key}.weight"
|
||||
)
|
||||
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
|
||||
if any("head.head" in k for k in original_state_dict):
|
||||
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
|
||||
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
|
||||
f"head.head.{lora_down_key}.weight"
|
||||
)
|
||||
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
|
||||
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
|
||||
f"head.head.{lora_up_key}.weight"
|
||||
)
|
||||
if "head.head.diff_b" in original_state_dict:
|
||||
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
|
||||
|
||||
# Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
|
||||
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
|
||||
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
|
||||
# an identity.
|
||||
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
|
||||
if f"head.head.{lora_down_key}.weight" in state_dict:
|
||||
logger.info(
|
||||
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
|
||||
)
|
||||
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
|
||||
down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
|
||||
up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
|
||||
converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
|
||||
*up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
|
||||
).T
|
||||
|
||||
for text_time in ["text_embedding", "time_embedding"]:
|
||||
if any(text_time in k for k in original_state_dict):
|
||||
for b_n in [0, 2]:
|
||||
@@ -2193,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
|
||||
state_dict = {convert_key(k): v for k, v in state_dict.items()}
|
||||
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_default:
|
||||
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
|
||||
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
@@ -2241,3 +2265,89 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
converted_state_dict = {}
|
||||
|
||||
prefix = "diffusion_model."
|
||||
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_double_layers = 8
|
||||
num_single_layers = 48
|
||||
lora_keys = ("lora_A", "lora_B")
|
||||
attn_types = ("img_attn", "txt_attn")
|
||||
|
||||
for sl in range(num_single_layers):
|
||||
single_block_prefix = f"single_blocks.{sl}"
|
||||
attn_prefix = f"single_transformer_blocks.{sl}.attn"
|
||||
|
||||
for lora_key in lora_keys:
|
||||
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{single_block_prefix}.linear1.{lora_key}.weight"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{single_block_prefix}.linear2.{lora_key}.weight"
|
||||
)
|
||||
|
||||
for dl in range(num_double_layers):
|
||||
transformer_block_prefix = f"transformer_blocks.{dl}"
|
||||
|
||||
for lora_key in lora_keys:
|
||||
for attn_type in attn_types:
|
||||
attn_prefix = f"{transformer_block_prefix}.attn"
|
||||
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
|
||||
fused_qkv_weight = original_state_dict.pop(qkv_key)
|
||||
|
||||
if lora_key == "lora_A":
|
||||
diff_attn_proj_keys = (
|
||||
["to_q", "to_k", "to_v"]
|
||||
if attn_type == "img_attn"
|
||||
else ["add_q_proj", "add_k_proj", "add_v_proj"]
|
||||
)
|
||||
for proj_key in diff_attn_proj_keys:
|
||||
converted_state_dict[f"{attn_prefix}.{proj_key}.{lora_key}.weight"] = torch.cat(
|
||||
[fused_qkv_weight]
|
||||
)
|
||||
else:
|
||||
sample_q, sample_k, sample_v = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
|
||||
if attn_type == "img_attn":
|
||||
converted_state_dict[f"{attn_prefix}.to_q.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{attn_prefix}.to_k.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{attn_prefix}.to_v.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
else:
|
||||
converted_state_dict[f"{attn_prefix}.add_q_proj.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{attn_prefix}.add_k_proj.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{attn_prefix}.add_v_proj.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
|
||||
proj_mappings = [
|
||||
("img_attn.proj", "attn.to_out.0"),
|
||||
("txt_attn.proj", "attn.to_add_out"),
|
||||
]
|
||||
for org_proj, diff_proj in proj_mappings:
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
|
||||
mlp_mappings = [
|
||||
("img_mlp.0", "ff.linear_in"),
|
||||
("img_mlp.2", "ff.linear_out"),
|
||||
("txt_mlp.0", "ff_context.linear_in"),
|
||||
("txt_mlp.2", "ff_context.linear_out"),
|
||||
]
|
||||
for org_mlp, diff_mlp in mlp_mappings:
|
||||
for lora_key in lora_keys:
|
||||
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
|
||||
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
|
||||
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -45,6 +45,7 @@ from .lora_conversion_utils import (
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_musubi_wan_lora_to_diffusers,
|
||||
_convert_non_diffusers_flux2_lora_to_diffusers,
|
||||
_convert_non_diffusers_hidream_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_ltxv_lora_to_diffusers,
|
||||
@@ -4940,7 +4941,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
|
||||
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
|
||||
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
|
||||
has_default = any("default." in k for k in state_dict)
|
||||
if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
|
||||
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
@@ -5083,6 +5085,209 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`Flux2Transformer2DModel`]. Specific to [`Flux2Pipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if is_ai_toolkit:
|
||||
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: Optional[dict] = None,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
|
||||
"""
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -62,6 +62,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ from .single_file_utils import (
|
||||
convert_chroma_transformer_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
convert_flux2_transformer_checkpoint_to_diffusers,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hidream_transformer_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
@@ -162,6 +163,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": lambda x: x,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"Flux2Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -140,6 +140,7 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"net.blocks.0.self_attn.q_proj.weight",
|
||||
"net.pos_embedder.dim_spatial_range",
|
||||
],
|
||||
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -189,6 +190,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
||||
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
||||
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
|
||||
@@ -649,6 +651,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "animatediff_v3"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux2"]):
|
||||
model_type = "flux-2-dev"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
@@ -3647,3 +3652,168 @@ def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_flux2_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
FLUX2_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Image and text input projections
|
||||
"img_in": "x_embedder",
|
||||
"txt_in": "context_embedder",
|
||||
# Timestep and guidance embeddings
|
||||
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
|
||||
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.in_layer": "time_guidance_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.out_layer": "time_guidance_embed.guidance_embedder.linear_2",
|
||||
# Modulation parameters
|
||||
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
|
||||
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
|
||||
"single_stream_modulation.lin": "single_stream_modulation.linear",
|
||||
# Final output layer
|
||||
# "final_layer.adaLN_modulation.1": "norm_out.linear", # Handle separately since we need to swap mod params
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP = {
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP = {
|
||||
# Handle fused QKV projections separately as we need to break into Q, K, V projections
|
||||
"img_attn.norm.query_norm": "attn.norm_q",
|
||||
"img_attn.norm.key_norm": "attn.norm_k",
|
||||
"img_attn.proj": "attn.to_out.0",
|
||||
"img_mlp.0": "ff.linear_in",
|
||||
"img_mlp.2": "ff.linear_out",
|
||||
"txt_attn.norm.query_norm": "attn.norm_added_q",
|
||||
"txt_attn.norm.key_norm": "attn.norm_added_k",
|
||||
"txt_attn.proj": "attn.to_add_out",
|
||||
"txt_mlp.0": "ff_context.linear_in",
|
||||
"txt_mlp.2": "ff_context.linear_out",
|
||||
}
|
||||
|
||||
FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP = {
|
||||
"linear1": "attn.to_qkv_mlp_proj",
|
||||
"linear2": "attn.to_out",
|
||||
"norm.query_norm": "attn.norm_q",
|
||||
"norm.key_norm": "attn.norm_k",
|
||||
}
|
||||
|
||||
def convert_flux2_single_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
# Mapping:
|
||||
# - single_blocks.{N}.linear1 --> single_transformer_blocks.{N}.attn.to_qkv_mlp_proj
|
||||
# - single_blocks.{N}.linear2 --> single_transformer_blocks.{N}.attn.to_out
|
||||
# - single_blocks.{N}.norm.query_norm.scale --> single_transformer_blocks.{N}.attn.norm_q.weight
|
||||
# - single_blocks.{N}.norm.key_norm.scale --> single_transformer_blocks.{N}.attn.norm_k.weight
|
||||
new_prefix = "single_transformer_blocks"
|
||||
if "single_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
new_within_block_name = FLUX2_TRANSFORMER_SINGLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
def convert_ada_layer_norm_weights(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight
|
||||
if ".weight" not in key:
|
||||
return
|
||||
|
||||
# If adaLN_modulation is in the key, swap scale and shift parameters
|
||||
# Original implementation is (shift, scale); diffusers implementation is (scale, shift)
|
||||
if "adaLN_modulation" in key:
|
||||
key_without_param_type, param_type = key.rsplit(".", maxsplit=1)
|
||||
# Assume all such keys are in the AdaLayerNorm key map
|
||||
new_key_without_param_type = FLUX2_TRANSFORMER_ADA_LAYER_NORM_KEY_MAP[key_without_param_type]
|
||||
new_key = ".".join([new_key_without_param_type, param_type])
|
||||
|
||||
swapped_weight = swap_scale_shift(state_dict.pop(key), 0)
|
||||
state_dict[new_key] = swapped_weight
|
||||
|
||||
return
|
||||
|
||||
def convert_flux2_double_stream_blocks(key: str, state_dict: dict[str, object]) -> None:
|
||||
# Skip if not a weight, bias, or scale
|
||||
if ".weight" not in key and ".bias" not in key and ".scale" not in key:
|
||||
return
|
||||
|
||||
new_prefix = "transformer_blocks"
|
||||
if "double_blocks." in key:
|
||||
parts = key.split(".")
|
||||
block_idx = parts[1]
|
||||
modality_block_name = parts[2] # img_attn, img_mlp, txt_attn, txt_mlp
|
||||
within_block_name = ".".join(parts[2:-1])
|
||||
param_type = parts[-1]
|
||||
|
||||
if param_type == "scale":
|
||||
param_type = "weight"
|
||||
|
||||
if "qkv" in within_block_name:
|
||||
fused_qkv_weight = state_dict.pop(key)
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
if "img" in modality_block_name:
|
||||
# double_blocks.{N}.img_attn.qkv --> transformer_blocks.{N}.attn.{to_q|to_k|to_v}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.to_q"
|
||||
new_k_name = "attn.to_k"
|
||||
new_v_name = "attn.to_v"
|
||||
elif "txt" in modality_block_name:
|
||||
# double_blocks.{N}.txt_attn.qkv --> transformer_blocks.{N}.attn.{add_q_proj|add_k_proj|add_v_proj}
|
||||
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
||||
new_q_name = "attn.add_q_proj"
|
||||
new_k_name = "attn.add_k_proj"
|
||||
new_v_name = "attn.add_v_proj"
|
||||
new_q_key = ".".join([new_prefix, block_idx, new_q_name, param_type])
|
||||
new_k_key = ".".join([new_prefix, block_idx, new_k_name, param_type])
|
||||
new_v_key = ".".join([new_prefix, block_idx, new_v_name, param_type])
|
||||
state_dict[new_q_key] = to_q_weight
|
||||
state_dict[new_k_key] = to_k_weight
|
||||
state_dict[new_v_key] = to_v_weight
|
||||
else:
|
||||
new_within_block_name = FLUX2_TRANSFORMER_DOUBLE_BLOCK_KEY_MAP[within_block_name]
|
||||
new_key = ".".join([new_prefix, block_idx, new_within_block_name, param_type])
|
||||
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"adaLN_modulation": convert_ada_layer_norm_weights,
|
||||
"double_blocks": convert_flux2_double_stream_blocks,
|
||||
"single_blocks": convert_flux2_single_stream_blocks,
|
||||
}
|
||||
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in FLUX2_TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
|
||||
update_state_dict(converted_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -35,7 +35,10 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
|
||||
_import_structure["autoencoders.autoencoder_kl_flux2"] = ["AutoencoderKLFlux2"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
@@ -82,26 +85,34 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_animate"] = ["WanAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
|
||||
_import_structure["transformers.transformer_z_image"] = ["ZImageTransformer2DModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
|
||||
@@ -132,6 +143,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLCosmos,
|
||||
AutoencoderKLFlux2,
|
||||
AutoencoderKLHunyuanImage,
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
@@ -168,8 +182,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .transformers import (
|
||||
AllegroTransformer3DModel,
|
||||
AuraFlowTransformer2DModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
BriaTransformer2DModel,
|
||||
ChromaTransformer2DModel,
|
||||
ChronoEditTransformer3DModel,
|
||||
CogVideoXTransformer3DModel,
|
||||
CogView3PlusTransformer2DModel,
|
||||
CogView4Transformer2DModel,
|
||||
@@ -178,9 +194,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
HunyuanVideoFramepackTransformer3DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
Kandinsky5Transformer3DModel,
|
||||
@@ -192,16 +210,20 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
OmniGenTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
PRXTransformer2DModel,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SanaVideoTransformer3DModel,
|
||||
SD3Transformer2DModel,
|
||||
SkyReelsV2Transformer3DModel,
|
||||
StableAudioDiTModel,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
WanAnimateTransformer3DModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
ZImageTransformer2DModel,
|
||||
)
|
||||
from .unets import (
|
||||
I2VGenXLUNet,
|
||||
|
||||
@@ -44,11 +44,16 @@ class ContextParallelConfig:
|
||||
|
||||
Args:
|
||||
ring_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
|
||||
attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
|
||||
of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
|
||||
for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
|
||||
context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
|
||||
ulysses_degree (`int`, *optional*, defaults to `1`):
|
||||
Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
|
||||
total number of devices in the context parallel mesh.
|
||||
Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
|
||||
local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
|
||||
KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
|
||||
good interconnect bandwidth.
|
||||
convert_to_fp32 (`bool`, *optional*, defaults to `True`):
|
||||
Whether to convert output and LSE to float32 for ring attention numerical stability.
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
@@ -79,29 +84,46 @@ class ContextParallelConfig:
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
|
||||
if self.ring_degree == 1 and self.ulysses_degree == 1:
|
||||
raise ValueError(
|
||||
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
|
||||
)
|
||||
if self.ring_degree < 1 or self.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if self.ring_degree > 1 and self.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
|
||||
@property
|
||||
def mesh_shape(self) -> Tuple[int, int]:
|
||||
return (self.ring_degree, self.ulysses_degree)
|
||||
|
||||
@property
|
||||
def mesh_dim_names(self) -> Tuple[str, str]:
|
||||
"""Dimension names for the device mesh."""
|
||||
return ("ring", "ulysses")
|
||||
|
||||
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._mesh = mesh
|
||||
if self.ring_degree is None:
|
||||
self.ring_degree = 1
|
||||
if self.ulysses_degree is None:
|
||||
self.ulysses_degree = 1
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
|
||||
if self.ulysses_degree * self.ring_degree > world_size:
|
||||
raise ValueError(
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
if self._flattened_mesh is None:
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
if self._ring_mesh is None:
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
if self._ulysses_mesh is None:
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
if self._ring_local_rank is None:
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
if self._ulysses_local_rank is None:
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -119,7 +141,7 @@ class ParallelConfig:
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
_device: torch.device = None
|
||||
_cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
_mesh: torch.distributed.device_mesh.DeviceMesh = None
|
||||
|
||||
def setup(
|
||||
self,
|
||||
@@ -127,14 +149,14 @@ class ParallelConfig:
|
||||
world_size: int,
|
||||
device: torch.device,
|
||||
*,
|
||||
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
|
||||
):
|
||||
self._rank = rank
|
||||
self._world_size = world_size
|
||||
self._device = device
|
||||
self._cp_mesh = cp_mesh
|
||||
self._mesh = mesh
|
||||
if self.context_parallel_config is not None:
|
||||
self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
|
||||
self.context_parallel_config.setup(rank, world_size, device, mesh)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@@ -105,7 +105,7 @@ class AttentionMixin:
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
|
||||
module.fuse_projections()
|
||||
|
||||
def unfuse_qkv_projections(self):
|
||||
@@ -114,13 +114,14 @@ class AttentionMixin:
|
||||
> [!WARNING] > This API is 🧪 experimental.
|
||||
"""
|
||||
for module in self.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if isinstance(module, AttentionModuleMixin) and module._supports_qkv_fusion:
|
||||
module.unfuse_projections()
|
||||
|
||||
|
||||
class AttentionModuleMixin:
|
||||
_default_processor_cls = None
|
||||
_available_processors = []
|
||||
_supports_qkv_fusion = True
|
||||
fused_projections = False
|
||||
|
||||
def set_processor(self, processor: AttentionProcessor) -> None:
|
||||
@@ -248,6 +249,14 @@ class AttentionModuleMixin:
|
||||
"""
|
||||
Fuse the query, key, and value projections into a single projection for efficiency.
|
||||
"""
|
||||
# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
|
||||
# single stream blocks are always fused)
|
||||
if not self._supports_qkv_fusion:
|
||||
logger.debug(
|
||||
f"{self.__class__.__name__} does not support fusing QKV projections, so `fuse_projections` will no-op."
|
||||
)
|
||||
return
|
||||
|
||||
# Skip if already fused
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
@@ -307,6 +316,11 @@ class AttentionModuleMixin:
|
||||
"""
|
||||
Unfuse the query, key, and value projections back to separate projections.
|
||||
"""
|
||||
# Skip if the AttentionModuleMixin subclass does not support fusion (for example, the QKV projections in Flux2
|
||||
# single stream blocks are always fused)
|
||||
if not self._supports_qkv_fusion:
|
||||
return
|
||||
|
||||
# Skip if not fused
|
||||
if not getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
@@ -16,8 +16,9 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -27,6 +28,8 @@ if torch.distributed.is_available():
|
||||
|
||||
from ..utils import (
|
||||
get_logger,
|
||||
is_aiter_available,
|
||||
is_aiter_version,
|
||||
is_flash_attn_3_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_version,
|
||||
@@ -40,13 +43,14 @@ from ..utils import (
|
||||
is_xformers_available,
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._modeling_parallel import ParallelConfig
|
||||
|
||||
_REQUIRED_FLASH_VERSION = "2.6.3"
|
||||
_REQUIRED_AITER_VERSION = "0.1.5"
|
||||
_REQUIRED_SAGE_VERSION = "2.1.1"
|
||||
_REQUIRED_FLEX_VERSION = "2.5.0"
|
||||
_REQUIRED_XLA_VERSION = "2.2"
|
||||
@@ -54,6 +58,7 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||
|
||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
|
||||
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
|
||||
_CAN_USE_NPU_ATTN = is_torch_npu_available()
|
||||
@@ -78,17 +83,10 @@ else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
if not is_kernels_available():
|
||||
raise ImportError(
|
||||
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
|
||||
)
|
||||
from ..utils.kernels_utils import _get_fa3_from_hub
|
||||
|
||||
flash_attn_interface_hub = _get_fa3_from_hub()
|
||||
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
else:
|
||||
flash_attn_3_func_hub = None
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
from sageattention import (
|
||||
@@ -162,22 +160,22 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
# - CP with sage attention, flex, xformers, other missing backends
|
||||
# - Add support for normal and CP training with backends that don't support it yet
|
||||
|
||||
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
|
||||
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
|
||||
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
||||
|
||||
|
||||
class AttentionBackendName(str, Enum):
|
||||
# EAGER = "eager"
|
||||
|
||||
# `flash-attn`
|
||||
FLASH = "flash"
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
|
||||
|
||||
# `aiter`
|
||||
AITER = "aiter"
|
||||
|
||||
# PyTorch native
|
||||
FLEX = "flex"
|
||||
NATIVE = "native"
|
||||
@@ -190,6 +188,7 @@ class AttentionBackendName(str, Enum):
|
||||
|
||||
# `sageattention`
|
||||
SAGE = "sage"
|
||||
SAGE_HUB = "sage_hub"
|
||||
SAGE_VARLEN = "sage_varlen"
|
||||
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
|
||||
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
|
||||
@@ -207,7 +206,7 @@ class _AttentionBackendRegistry:
|
||||
_backends = {}
|
||||
_constraints = {}
|
||||
_supported_arg_names = {}
|
||||
_supports_context_parallel = {}
|
||||
_supports_context_parallel = set()
|
||||
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
|
||||
_checks_enabled = DIFFUSERS_ATTN_CHECKS
|
||||
|
||||
@@ -224,7 +223,9 @@ class _AttentionBackendRegistry:
|
||||
cls._backends[backend] = func
|
||||
cls._constraints[backend] = constraints or []
|
||||
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
|
||||
cls._supports_context_parallel[backend] = supports_context_parallel
|
||||
if supports_context_parallel:
|
||||
cls._supports_context_parallel.add(backend.value)
|
||||
|
||||
return func
|
||||
|
||||
return decorator
|
||||
@@ -238,15 +239,37 @@ class _AttentionBackendRegistry:
|
||||
return list(cls._backends.keys())
|
||||
|
||||
@classmethod
|
||||
def _is_context_parallel_enabled(
|
||||
cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
|
||||
def _is_context_parallel_available(
|
||||
cls,
|
||||
backend: AttentionBackendName,
|
||||
) -> bool:
|
||||
supports_context_parallel = backend in cls._supports_context_parallel
|
||||
is_degree_greater_than_1 = parallel_config is not None and (
|
||||
parallel_config.context_parallel_config.ring_degree > 1
|
||||
or parallel_config.context_parallel_config.ulysses_degree > 1
|
||||
)
|
||||
return supports_context_parallel and is_degree_greater_than_1
|
||||
supports_context_parallel = backend.value in cls._supports_context_parallel
|
||||
return supports_context_parallel
|
||||
|
||||
|
||||
@dataclass
|
||||
class _HubKernelConfig:
|
||||
"""Configuration for downloading and using a hub-based attention kernel."""
|
||||
|
||||
repo_id: str
|
||||
function_attr: str
|
||||
revision: Optional[str] = None
|
||||
kernel_fn: Optional[Callable] = None
|
||||
|
||||
|
||||
# Registry for hub-based attention kernels
|
||||
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
||||
),
|
||||
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -293,14 +316,6 @@ def dispatch_attention_fn(
|
||||
backend_name = AttentionBackendName(backend)
|
||||
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
|
||||
|
||||
if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
|
||||
backend_name, parallel_config
|
||||
):
|
||||
raise ValueError(
|
||||
f"Backend {backend_name} either does not support context parallelism or context parallelism "
|
||||
f"was enabled with a world size of 1."
|
||||
)
|
||||
|
||||
kwargs = {
|
||||
"query": query,
|
||||
"key": key,
|
||||
@@ -379,12 +394,18 @@ def _check_shape(
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
# Expected shapes:
|
||||
# query: (batch_size, seq_len_q, num_heads, head_dim)
|
||||
# key: (batch_size, seq_len_kv, num_heads, head_dim)
|
||||
# value: (batch_size, seq_len_kv, num_heads, head_dim)
|
||||
# attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
|
||||
# or (batch_size, num_heads, seq_len_q, seq_len_kv)
|
||||
if query.shape[-1] != key.shape[-1]:
|
||||
raise ValueError("Query and key must have the same last dimension.")
|
||||
if query.shape[-2] != value.shape[-2]:
|
||||
raise ValueError("Query and value must have the same second to last dimension.")
|
||||
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
|
||||
raise ValueError("Attention mask must match the key's second to last dimension.")
|
||||
raise ValueError("Query and key must have the same head dimension.")
|
||||
if key.shape[-3] != value.shape[-3]:
|
||||
raise ValueError("Key and value must have the same sequence length.")
|
||||
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
|
||||
raise ValueError("Attention mask must match the key's sequence length.")
|
||||
|
||||
|
||||
# ===== Helper functions =====
|
||||
@@ -403,15 +424,17 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||
)
|
||||
|
||||
# TODO: add support Hub variant of FA3 varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB]:
|
||||
if not DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
|
||||
)
|
||||
# TODO: add support Hub variant of varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
if not _CAN_USE_AITER_ATTN:
|
||||
raise RuntimeError(
|
||||
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
@@ -555,6 +578,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
|
||||
# ===== Helpers for downloading kernels =====
|
||||
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
if backend not in _HUB_KERNELS_REGISTRY:
|
||||
return
|
||||
config = _HUB_KERNELS_REGISTRY[backend]
|
||||
|
||||
if config.kernel_fn is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
from kernels import get_kernel
|
||||
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
||||
kernel_func = getattr(kernel_module, config.function_attr)
|
||||
|
||||
# Cache the downloaded kernel function in the config object
|
||||
config.kernel_fn = kernel_func
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# ===== torch op registrations =====
|
||||
# Registrations are required for fullgraph tracing compatibility
|
||||
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
|
||||
@@ -630,6 +676,86 @@ def _(
|
||||
# ===== Helper functions to use attention backends with templated CP autograd functions =====
|
||||
|
||||
|
||||
def _native_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
# Native attention does not return_lse
|
||||
if return_lse:
|
||||
raise ValueError("Native attention does not support return_lse=True")
|
||||
|
||||
# used for backward pass
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value)
|
||||
ctx.attn_mask = attn_mask
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.is_causal = is_causal
|
||||
ctx.scale = scale
|
||||
ctx.enable_gqa = enable_gqa
|
||||
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _native_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
query, key, value = ctx.saved_tensors
|
||||
|
||||
query.requires_grad_(True)
|
||||
key.requires_grad_(True)
|
||||
value.requires_grad_(True)
|
||||
|
||||
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query_t,
|
||||
key=key_t,
|
||||
value=value_t,
|
||||
attn_mask=ctx.attn_mask,
|
||||
dropout_p=ctx.dropout_p,
|
||||
is_causal=ctx.is_causal,
|
||||
scale=ctx.scale,
|
||||
enable_gqa=ctx.enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
grad_out_t = grad_out.permute(0, 2, 1, 3)
|
||||
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
|
||||
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
|
||||
)
|
||||
|
||||
grad_query = grad_query_t.permute(0, 2, 1, 3)
|
||||
grad_key = grad_key_t.permute(0, 2, 1, 3)
|
||||
grad_value = grad_value_t.permute(0, 2, 1, 3)
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
|
||||
# forward declaration:
|
||||
# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
@@ -1228,6 +1354,38 @@ def _flash_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_VARLEN,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -1309,6 +1467,7 @@ def _flash_attention_3(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -1322,7 +1481,11 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
out = flash_attn_3_func_hub(
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
@@ -1397,6 +1560,47 @@ def _flash_varlen_attention_3(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.AITER,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
)
|
||||
def _aiter_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if not return_lse and torch.is_grad_enabled():
|
||||
# aiter requires return_lse=True by assertion when gradients are enabled.
|
||||
out, lse, *_ = aiter_flash_attn_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_lse=True,
|
||||
)
|
||||
else:
|
||||
out = aiter_flash_attn_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLEX,
|
||||
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
|
||||
@@ -1463,6 +1667,7 @@ def _native_flex_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.NATIVE,
|
||||
constraints=[_check_device, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _native_attention(
|
||||
query: torch.Tensor,
|
||||
@@ -1478,18 +1683,35 @@ def _native_attention(
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
if _parallel_config is None:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=dropout_p,
|
||||
is_causal=is_causal,
|
||||
scale=scale,
|
||||
enable_gqa=enable_gqa,
|
||||
)
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op=_native_attention_forward_op,
|
||||
backward_op=_native_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -1756,6 +1978,38 @@ def _sage_attention(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user