Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| be55fa631f |
@@ -116,7 +116,6 @@ jobs:
|
||||
run:
|
||||
shell: bash
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
module: [models, schedulers, lora, others, single_file, examples]
|
||||
@@ -291,118 +290,64 @@ jobs:
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# M1 runner currently not well supported
|
||||
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
|
||||
# run_nightly_tests_apple_m1:
|
||||
# name: Nightly PyTorch MPS tests on MacOS
|
||||
# runs-on: [ self-hosted, apple-m1 ]
|
||||
# if: github.event_name == 'schedule'
|
||||
#
|
||||
# steps:
|
||||
# - name: Checkout diffusers
|
||||
# uses: actions/checkout@v3
|
||||
# with:
|
||||
# fetch-depth: 2
|
||||
#
|
||||
# - name: Clean checkout
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# run: |
|
||||
# git clean -fxd
|
||||
# - name: Setup miniconda
|
||||
# uses: ./.github/actions/setup-miniconda
|
||||
# with:
|
||||
# python-version: 3.9
|
||||
#
|
||||
# - name: Install dependencies
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# run: |
|
||||
# ${CONDA_RUN} python -m pip install --upgrade pip uv
|
||||
# ${CONDA_RUN} python -m uv pip install -e [quality,test]
|
||||
# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
||||
# ${CONDA_RUN} python -m uv pip install pytest-reportlog
|
||||
# - name: Environment
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# run: |
|
||||
# ${CONDA_RUN} python utils/print_env.py
|
||||
# - name: Run nightly PyTorch tests on M1 (MPS)
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# env:
|
||||
# HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# run: |
|
||||
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
# --report-log=tests_torch_mps.log \
|
||||
# tests/
|
||||
# - name: Failure short reports
|
||||
# if: ${{ failure() }}
|
||||
# run: cat reports/tests_torch_mps_failures_short.txt
|
||||
#
|
||||
# - name: Test suite reports artifacts
|
||||
# if: ${{ always() }}
|
||||
# uses: actions/upload-artifact@v2
|
||||
# with:
|
||||
# name: torch_mps_test_reports
|
||||
# path: reports
|
||||
#
|
||||
# - name: Generate Report and Notify Channel
|
||||
# if: always()
|
||||
# run: |
|
||||
# pip install slack_sdk tabulate
|
||||
# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY run_nightly_tests_apple_m1:
|
||||
# name: Nightly PyTorch MPS tests on MacOS
|
||||
# runs-on: [ self-hosted, apple-m1 ]
|
||||
# if: github.event_name == 'schedule'
|
||||
#
|
||||
# steps:
|
||||
# - name: Checkout diffusers
|
||||
# uses: actions/checkout@v3
|
||||
# with:
|
||||
# fetch-depth: 2
|
||||
#
|
||||
# - name: Clean checkout
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# run: |
|
||||
# git clean -fxd
|
||||
# - name: Setup miniconda
|
||||
# uses: ./.github/actions/setup-miniconda
|
||||
# with:
|
||||
# python-version: 3.9
|
||||
#
|
||||
# - name: Install dependencies
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# run: |
|
||||
# ${CONDA_RUN} python -m pip install --upgrade pip uv
|
||||
# ${CONDA_RUN} python -m uv pip install -e [quality,test]
|
||||
# ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
# ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
||||
# ${CONDA_RUN} python -m uv pip install pytest-reportlog
|
||||
# - name: Environment
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# run: |
|
||||
# ${CONDA_RUN} python utils/print_env.py
|
||||
# - name: Run nightly PyTorch tests on M1 (MPS)
|
||||
# shell: arch -arch arm64 bash {0}
|
||||
# env:
|
||||
# HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
# HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# run: |
|
||||
# ${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
# --report-log=tests_torch_mps.log \
|
||||
# tests/
|
||||
# - name: Failure short reports
|
||||
# if: ${{ failure() }}
|
||||
# run: cat reports/tests_torch_mps_failures_short.txt
|
||||
#
|
||||
# - name: Test suite reports artifacts
|
||||
# if: ${{ always() }}
|
||||
# uses: actions/upload-artifact@v2
|
||||
# with:
|
||||
# name: torch_mps_test_reports
|
||||
# path: reports
|
||||
#
|
||||
# - name: Generate Report and Notify Channel
|
||||
# if: always()
|
||||
# run: |
|
||||
# pip install slack_sdk tabulate
|
||||
# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
run_nightly_tests_apple_m1:
|
||||
name: Nightly PyTorch MPS tests on MacOS
|
||||
runs-on: [ self-hosted, apple-m1 ]
|
||||
if: github.event_name == 'schedule'
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Clean checkout
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
git clean -fxd
|
||||
|
||||
- name: Setup miniconda
|
||||
uses: ./.github/actions/setup-miniconda
|
||||
with:
|
||||
python-version: 3.9
|
||||
|
||||
- name: Install dependencies
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pip install --upgrade pip uv
|
||||
${CONDA_RUN} python -m uv pip install -e [quality,test]
|
||||
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
||||
${CONDA_RUN} python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python utils/print_env.py
|
||||
|
||||
- name: Run nightly PyTorch tests on M1 (MPS)
|
||||
shell: arch -arch arm64 bash {0}
|
||||
env:
|
||||
HF_HOME: /System/Volumes/Data/mnt/cache
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pytest -n 1 -s -v --make-reports=tests_torch_mps \
|
||||
--report-log=tests_torch_mps.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_mps_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_mps_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
name: Fast GPU Tests on main
|
||||
name: Slow Tests on main
|
||||
|
||||
on:
|
||||
push:
|
||||
@@ -112,8 +112,6 @@ jobs:
|
||||
run:
|
||||
shell: bash
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
module: [models, schedulers, lora, others, single_file]
|
||||
steps:
|
||||
|
||||
@@ -1,386 +0,0 @@
|
||||
name: Release Fast GPU Tests on main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "v*.*.*-release"
|
||||
- "v*.*.*-patch"
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
PYTEST_TIMEOUT: 600
|
||||
PIPELINE_USAGE_CUTOFF: 50000
|
||||
|
||||
jobs:
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
outputs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Fetch Pipeline Matrix
|
||||
id: fetch_pipeline_matrix
|
||||
run: |
|
||||
matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
|
||||
echo $matrix
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
|
||||
torch_pipelines_cuda_tests:
|
||||
name: Torch Pipelines CUDA Tests
|
||||
needs: setup_torch_cuda_pipeline_matrix
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 8
|
||||
matrix:
|
||||
module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
|
||||
torch_cuda_tests:
|
||||
name: Torch CUDA Tests
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
module: [models, schedulers, lora, others, single_file]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run PyTorch CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_cuda \
|
||||
tests/${{ matrix.module }}
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_torch_cuda_stats.txt
|
||||
cat reports/tests_torch_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
flax_tpu_tests:
|
||||
name: Flax TPU Tests
|
||||
runs-on: docker-tpu
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
onnx_cuda_tests:
|
||||
name: ONNX CUDA Tests
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: onnx_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
run_torch_compile_tests:
|
||||
name: PyTorch Compile CUDA tests
|
||||
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-compile-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
|
||||
run_xformers_tests:
|
||||
name: PyTorch xformers CUDA tests
|
||||
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-xformers-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_xformers_test_reports
|
||||
path: reports
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install timm
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/examples_torch_cuda_stats.txt
|
||||
cat reports/examples_torch_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
+2
-2
@@ -57,7 +57,7 @@ Any question or comment related to the Diffusers library can be asked on the [di
|
||||
- ...
|
||||
|
||||
Every question that is asked on the forum or on Discord actively encourages the community to publicly
|
||||
share knowledge and might very well help a beginner in the future who has the same question you're
|
||||
share knowledge and might very well help a beginner in the future that has the same question you're
|
||||
having. Please do pose any questions you might have.
|
||||
In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from.
|
||||
|
||||
@@ -503,4 +503,4 @@ $ git push --set-upstream origin your-branch-for-syncing
|
||||
|
||||
### Style guide
|
||||
|
||||
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
|
||||
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
|
||||
+2
-2
@@ -15,7 +15,7 @@ specific language governing permissions and limitations under the License.
|
||||
🧨 Diffusers provides **state-of-the-art** pretrained diffusion models across multiple modalities.
|
||||
Its purpose is to serve as a **modular toolbox** for both inference and training.
|
||||
|
||||
We aim to build a library that stands the test of time and therefore take API design very seriously.
|
||||
We aim at building a library that stands the test of time and therefore take API design very seriously.
|
||||
|
||||
In a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefore, most of our design choices are based on [PyTorch's Design Principles](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy). Let's go over the most important ones:
|
||||
|
||||
@@ -107,4 +107,4 @@ The following design principles are followed:
|
||||
- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.
|
||||
- The `step(...)` function takes a predicted model output and the "current" sample (x_t) and returns the "previous", slightly more denoised sample (x_t-1).
|
||||
- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a "black box".
|
||||
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
|
||||
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
|
||||
@@ -15,7 +15,9 @@
|
||||
|
||||
# CogVideoX
|
||||
|
||||
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
|
||||
<!-- TODO: update paper with ArXiv link when ready. -->
|
||||
|
||||
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://github.com/THUDM/CogVideo/blob/main/resources/CogVideoX.pdf) from Tsinghua University & ZhipuAI.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
@@ -29,10 +31,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
|
||||
|
||||
There are two models available that can be used with the CogVideoX pipeline:
|
||||
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b)
|
||||
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b)
|
||||
|
||||
## Inference
|
||||
|
||||
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
|
||||
@@ -45,42 +43,43 @@ from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b").to("cuda")
|
||||
prompt = (
|
||||
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
||||
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
||||
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
"atmosphere of this unique musical performance."
|
||||
)
|
||||
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
Then change the memory layout of the pipelines `transformer` component to `torch.channels_last`:
|
||||
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
|
||||
|
||||
```python
|
||||
pipe.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Finally, compile the components and run inference:
|
||||
|
||||
```python
|
||||
pipe.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
||||
pipeline.transformer = torch.compile(pipeline.transformer)
|
||||
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
|
||||
|
||||
# CogVideoX works well with long and well-described prompts
|
||||
# CogVideoX works very well with long and well-described prompts
|
||||
prompt = "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."
|
||||
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
```
|
||||
|
||||
The [benchmark](https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f) results on an 80GB A100 machine are:
|
||||
The [benchmark](TODO: link) results on an 80GB A100 machine are:
|
||||
|
||||
```
|
||||
Without torch.compile(): Average inference time: 96.89 seconds.
|
||||
With torch.compile(): Average inference time: 76.27 seconds.
|
||||
Without torch.compile(): Average inference time: TODO seconds.
|
||||
With torch.compile(): Average inference time: TODO seconds.
|
||||
```
|
||||
|
||||
### Memory optimization
|
||||
|
||||
CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds of video at 8 FPS) with output resolution 720x480 (W x H), which makes it not possible to run on consumer GPUs or free-tier T4 Colab. The following memory optimizations could be used to reduce the memory footprint. For replication, you can refer to [this](https://gist.github.com/a-r-r-o-w/3959a03f15be5c9bd1fe545b09dfcc93) script.
|
||||
|
||||
- `pipe.enable_model_cpu_offload()`:
|
||||
- Without enabling cpu offloading, memory usage is `33 GB`
|
||||
- With enabling cpu offloading, memory usage is `19 GB`
|
||||
- `pipe.vae.enable_tiling()`:
|
||||
- With enabling cpu offloading and tiling, memory usage is `11 GB`
|
||||
- `pipe.vae.enable_slicing()`
|
||||
|
||||
## CogVideoXPipeline
|
||||
|
||||
[[autodoc]] CogVideoXPipeline
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
<!--Copyright 2023 The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
@@ -22,16 +22,7 @@ The abstract from the paper is:
|
||||
|
||||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||
|
||||
This controlnet code is mainly implemented by [The InstantX Team](https://huggingface.co/InstantX). The inpainting-related code was developed by [The Alimama Creative Team](https://huggingface.co/alimama-creative). You can find pre-trained checkpoints for SD3-ControlNet in the table below:
|
||||
|
||||
|
||||
| ControlNet type | Developer | Link |
|
||||
| -------- | ---------- | ---- |
|
||||
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Canny) |
|
||||
| Pose | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Pose) |
|
||||
| Tile | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Tile) |
|
||||
| Inpainting | [The AlimamaCreative Team](https://huggingface.co/alimama-creative) | [link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |
|
||||
|
||||
This code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for SD3-ControlNet on [The InstantX Team](https://huggingface.co/InstantX) Hub profile.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -44,10 +35,5 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusion3ControlNetInpaintingPipeline
|
||||
[[autodoc]] pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet_inpainting.StableDiffusion3ControlNetInpaintingPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusion3PipelineOutput
|
||||
[[autodoc]] pipelines.stable_diffusion_3.pipeline_output.StableDiffusion3PipelineOutput
|
||||
|
||||
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||

|
||||
|
||||
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](https://github.com/Kwai-Kolors/Kolors). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
|
||||
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](kwai-kolors@kuaishou.com). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
|
||||
|
||||
The abstract from the technical report is:
|
||||
|
||||
@@ -74,7 +74,7 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
|
||||
pipe = KolorsPipeline.from_pretrained(
|
||||
"Kwai-Kolors/Kolors-diffusers", image_encoder=image_encoder, torch_dtype=torch.float16, variant="fp16"
|
||||
)
|
||||
).to("cuda")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
|
||||
|
||||
pipe.load_ip_adapter(
|
||||
|
||||
@@ -20,7 +20,7 @@ The abstract from the paper is:
|
||||
|
||||
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
|
||||
|
||||
PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.
|
||||
PAG can be used by specifying the `pag_applied_layers` as a parameter when instantiating a PAG pipeline. It can be a single string or a list of strings. Each string can be a unique layer identifier or a regular expression to identify one or more layers.
|
||||
|
||||
- Full identifier as a normal string: `down_blocks.2.attentions.0.transformer_blocks.0.attn1.processor`
|
||||
- Full identifier as a RegEx: `down_blocks.2.(attentions|motion_modules).0.transformer_blocks.0.attn1.processor`
|
||||
@@ -46,7 +46,7 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
## KolorsPAGPipeline
|
||||
[[autodoc]] KolorsPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPAGPipeline
|
||||
[[autodoc]] StableDiffusionPAGPipeline
|
||||
@@ -78,10 +78,6 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLControlNetPAGImg2ImgPipeline
|
||||
[[autodoc]] StableDiffusionXLControlNetPAGImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusion3PAGPipeline
|
||||
[[autodoc]] StableDiffusion3PAGPipeline
|
||||
|
||||
@@ -238,7 +238,7 @@ Pretty impressive! Let's tweak the second image - corresponding to the `Generato
|
||||
```python
|
||||
prompts = [
|
||||
"portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
"portrait photo of an old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
"portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
"portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
"portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
]
|
||||
|
||||
@@ -48,7 +48,7 @@ accelerate launch run_distributed.py --num_processes=2
|
||||
|
||||
<Tip>
|
||||
|
||||
Refer to this minimal example [script](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) for running inference across multiple GPUs. To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
|
||||
To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -108,4 +108,4 @@ torchrun run_distributed.py --nproc_per_node=2
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
|
||||
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
|
||||
@@ -10,7 +10,7 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Diffusers에 기여하는 방법 🧨 [[how-to-contribute-to-diffusers-]]
|
||||
# Diffusers에 기여하는 방법 🧨
|
||||
|
||||
오픈 소스 커뮤니티에서의 기여를 환영합니다! 누구나 참여할 수 있으며, 코드뿐만 아니라 질문에 답변하거나 문서를 개선하는 등 모든 유형의 참여가 가치 있고 감사히 여겨집니다. 질문에 답변하고 다른 사람들을 도와주며 소통하고 문서를 개선하는 것은 모두 커뮤니티에게 큰 도움이 됩니다. 따라서 관심이 있다면 두려워하지 말고 참여해보세요!
|
||||
|
||||
@@ -18,9 +18,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
어떤 방식으로든 기여하려는 경우, 우리는 개방적이고 환영하며 친근한 커뮤니티의 일부가 되기 위해 노력하고 있습니다. 우리의 [행동 강령](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md)을 읽고 상호 작용 중에 이를 존중하도록 주의해주시기 바랍니다. 또한 프로젝트를 안내하는 [윤리 지침](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines)에 익숙해지고 동일한 투명성과 책임성의 원칙을 준수해주시기를 부탁드립니다.
|
||||
|
||||
우리는 커뮤니티로부터의 피드백을 매우 중요하게 생각하므로, 라이브러리를 개선하는 데 도움이 될 가치 있는 피드백이 있다고 생각되면 망설이지 말고 의견을 제시해주세요 - 모든 메시지, 댓글, 이슈, Pull Request(PR)는 읽히고 고려됩니다.
|
||||
우리는 커뮤니티로부터의 피드백을 매우 중요하게 생각하므로, 라이브러리를 개선하는 데 도움이 될 가치 있는 피드백이 있다고 생각되면 망설이지 말고 의견을 제시해주세요 - 모든 메시지, 댓글, 이슈, 풀 리퀘스트(PR)는 읽히고 고려됩니다.
|
||||
|
||||
## 개요 [[overview]]
|
||||
## 개요
|
||||
|
||||
이슈에 있는 질문에 답변하는 것에서부터 코어 라이브러리에 새로운 diffusion 모델을 추가하는 것까지 다양한 방법으로 기여를 할 수 있습니다.
|
||||
|
||||
@@ -38,9 +38,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
앞서 말한 대로, **모든 기여는 커뮤니티에게 가치가 있습니다**. 이어지는 부분에서 각 기여에 대해 조금 더 자세히 설명하겠습니다.
|
||||
|
||||
4부터 9까지의 모든 기여에는 Pull Request을 열어야 합니다. [Pull Request 열기](#how-to-open-a-pr)에서 자세히 설명되어 있습니다.
|
||||
4부터 9까지의 모든 기여에는 PR을 열어야 합니다. [PR을 열기](#how-to-open-a-pr)에서 자세히 설명되어 있습니다.
|
||||
|
||||
### 1. Diffusers 토론 포럼이나 Diffusers Discord에서 질문하고 답변하기 [[1-asking-and-answering-questions-on-the-diffusers-discussion-forum-or-on-the-diffusers-discord]]
|
||||
### 1. Diffusers 토론 포럼이나 Diffusers Discord에서 질문하고 답변하기
|
||||
|
||||
Diffusers 라이브러리와 관련된 모든 질문이나 의견은 [토론 포럼](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)이나 [Discord](https://discord.gg/G7tWnz98XR)에서 할 수 있습니다. 이러한 질문과 의견에는 다음과 같은 내용이 포함됩니다(하지만 이에 국한되지는 않습니다):
|
||||
- 지식을 공유하기 위해서 훈련 또는 추론 실험에 대한 결과 보고
|
||||
@@ -54,7 +54,7 @@ Diffusers 라이브러리와 관련된 모든 질문이나 의견은 [토론 포
|
||||
- Diffusion 모델에 대한 윤리적 질문
|
||||
- ...
|
||||
|
||||
포럼이나 Discord에서 질문을 하면 커뮤니티가 지식을 공개적으로 공유하도록 장려되며, 향후 동일한 질문을 가진 초보자에게도 도움이 될 수 있습니다. 따라서 궁금한 질문은 언제든지 하시기 바랍니다.
|
||||
포럼이나 Discord에서 질문을 하면 커뮤니티가 지식을 공개적으로 공유하도록 장려되며, 미래에 동일한 질문을 가진 초보자에게도 도움이 될 수 있습니다. 따라서 궁금한 질문은 언제든지 하시기 바랍니다.
|
||||
또한, 이러한 질문에 답변하는 것은 커뮤니티에게 매우 큰 도움이 됩니다. 왜냐하면 이렇게 하면 모두가 학습할 수 있는 공개적인 지식을 문서화하기 때문입니다.
|
||||
|
||||
**주의**하십시오. 질문이나 답변에 투자하는 노력이 많을수록 공개적으로 문서화된 지식의 품질이 높아집니다. 마찬가지로, 잘 정의되고 잘 답변된 질문은 모두에게 접근 가능한 고품질 지식 데이터베이스를 만들어줍니다. 반면에 잘못된 질문이나 답변은 공개 지식 데이터베이스의 전반적인 품질을 낮출 수 있습니다.
|
||||
@@ -64,9 +64,9 @@ Diffusers 라이브러리와 관련된 모든 질문이나 의견은 [토론 포
|
||||
[*포럼*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63)은 구글과 같은 검색 엔진에서 더 잘 색인화됩니다. 게시물은 인기에 따라 순위가 매겨지며, 시간순으로 정렬되지 않습니다. 따라서 이전에 게시한 질문과 답변을 쉽게 찾을 수 있습니다.
|
||||
또한, 포럼에 게시된 질문과 답변은 쉽게 링크할 수 있습니다.
|
||||
반면 *Discord*는 채팅 형식으로 되어 있어 빠른 대화를 유도합니다.
|
||||
질문에 대한 답변을 빠르게 받을 수는 있겠지만, 시간이 지나면 질문이 더 이상 보이지 않습니다. 또한, Discord에서 이전에 게시된 정보를 찾는 것은 훨씬 어렵습니다. 따라서 포럼을 사용하여 고품질의 질문과 답변을 하여 커뮤니티를 위한 오래 지속되는 지식을 만들기를 권장합니다. Discord에서의 토론이 매우 흥미로운 답변과 결론을 이끌어내는 경우, 해당 정보를 포럼에 게시하여 향후 독자들에게 더 쉽게 액세스할 수 있도록 권장합니다.
|
||||
질문에 대한 답변을 빠르게 받을 수는 있겠지만, 시간이 지나면 질문이 더 이상 보이지 않습니다. 또한, Discord에서 이전에 게시된 정보를 찾는 것은 훨씬 어렵습니다. 따라서 포럼을 사용하여 고품질의 질문과 답변을 하여 커뮤니티를 위한 오래 지속되는 지식을 만들기를 권장합니다. Discord에서의 토론이 매우 흥미로운 답변과 결론을 이끌어내는 경우, 해당 정보를 포럼에 게시하여 미래 독자들에게 더 쉽게 액세스할 수 있도록 권장합니다.
|
||||
|
||||
### 2. GitHub 이슈 탭에서 새로운 이슈 열기 [[2-opening-new-issues-on-the-github-issues-tab]]
|
||||
### 2. GitHub 이슈 탭에서 새로운 이슈 열기
|
||||
|
||||
🧨 Diffusers 라이브러리는 사용자들이 마주치는 문제를 알려주는 덕분에 견고하고 신뢰할 수 있습니다. 따라서 이슈를 보고해주셔서 감사합니다.
|
||||
|
||||
@@ -81,52 +81,53 @@ Diffusers 라이브러리와 관련된 모든 질문이나 의견은 [토론 포
|
||||
- 이슈가 최신 Diffusers 버전으로 업데이트하면 해결될 수 있는지 확인해주세요. 이슈를 게시하기 전에 `python -c "import diffusers; print(diffusers.__version__)"` 명령을 실행하여 현재 사용 중인 Diffusers 버전이 최신 버전과 일치하거나 더 높은지 확인해주세요.
|
||||
- 새로운 이슈를 열 때 투자하는 노력이 많을수록 답변의 품질이 높아지고 Diffusers 이슈 전체의 품질도 향상됩니다.
|
||||
|
||||
#### 2.1 재현 가능한 최소한의 버그 리포트 [[21-reproducible-minimal-bug-reports]]
|
||||
#### 2.1 재현가능하고 최소한인 버그 리포트
|
||||
|
||||
새로운 이슈는 일반적으로 다음과 같은 내용을 포함합니다.
|
||||
|
||||
버그 리포트는 항상 재현 가능한 코드 조각을 포함하고 가능한 한 최소한이어야 하며 간결해야 합니다.
|
||||
버그 보고서는 항상 재현 가능한 코드 조각을 포함하고 가능한 한 최소한이어야 하며 간결해야 합니다.
|
||||
자세히 말하면:
|
||||
- 버그를 가능한 한 좁혀야 합니다. **전체 코드 파일을 그냥 던지지 마세요**.
|
||||
- 코드의 서식을 지정해야 합니다.
|
||||
- Diffusers가 의존하는 외부 라이브러리를 제외한 다른 외부 라이브러리는 포함하지 마십시오.
|
||||
- **항상** 사용자 환경에 대한 모든 필요한 정보를 제공하세요. 이를 위해 쉘에서 `diffusers-cli env`를 실행하고 표시된 정보를 이슈에 복사하여 붙여넣을 수 있습니다.
|
||||
- 이슈를 설명해야 합니다. 독자가 문제가 무엇인지, 왜 문제가 되는지 모른다면 이슈를 해결할 수 없습니다.
|
||||
- **항상** 독자가 가능한 한 적은 노력으로 문제를 재현할 수 있어야 합니다. 코드 조각이 라이브러리가 없거나 정의되지 않은 변수 때문에 실행되지 않는 경우 독자가 도움을 줄 수 없습니다. 재현 가능한 코드 조각이 가능한 한 최소화되고 간단한 Python 셸에 복사하여 붙여넣을 수 있도록 해야 합니다.
|
||||
- **반드시** 환경에 대한 모든 필요한 정보를 제공해야 합니다. 이를 위해 쉘에서 `diffusers-cli env`를 실행하고 표시된 정보를 이슈에 복사하여 붙여넣을 수 있습니다.
|
||||
- 이슈를 설명해야 합니다. 독자가 문제가 무엇이며 왜 문제인지 모르면 해결할 수 없습니다.
|
||||
- **항상** 독자가 가능한 한 적은 노력으로 문제를 재현할 수 있도록 해야 합니다. 코드 조각이 라이브러리가 없거나 정의되지 않은 변수 때문에 실행되지 않는 경우 독자가 도움을 줄 수 없습니다. 재현 가능한 코드 조각이 가능한 한 최소화되고 간단한 Python 셸에 복사하여 붙여넣을 수 있도록 해야 합니다.
|
||||
- 문제를 재현하기 위해 모델과/또는 데이터셋이 필요한 경우 독자가 해당 모델이나 데이터셋에 접근할 수 있도록 해야 합니다. 모델이나 데이터셋을 [Hub](https://huggingface.co)에 업로드하여 쉽게 다운로드할 수 있도록 할 수 있습니다. 문제 재현을 가능한 한 쉽게하기 위해 모델과 데이터셋을 가능한 한 작게 유지하려고 노력하세요.
|
||||
|
||||
자세한 내용은 [좋은 이슈 작성 방법](#how-to-write-a-good-issue) 섹션을 참조하세요.
|
||||
|
||||
버그 리포트를 열려면 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml)를 클릭하세요.
|
||||
버그 보고서를 열려면 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml)를 클릭하세요.
|
||||
|
||||
|
||||
#### 2.2. 기능 요청 [[22-feature-requests]]
|
||||
#### 2.2. 기능 요청
|
||||
|
||||
세계적인 기능 요청은 다음 사항을 다룹니다:
|
||||
|
||||
1. 먼저 동기부여:
|
||||
* 라이브러리와 관련된 문제/불만이 있나요? 그렇다면 왜 그런지 설명해주세요. 문제를 보여주는 코드 조각을 제공하는 것이 가장 좋습니다.
|
||||
* 라이브러리와 관련된 문제/불만이 있는가요? 그렇다면 왜 그런지 설명해주세요. 문제를 보여주는 코드 조각을 제공하는 것이 가장 좋습니다.
|
||||
* 프로젝트에 필요한 기능인가요? 우리는 그에 대해 듣고 싶습니다!
|
||||
* 커뮤니티에 도움이 될 수 있는 것을 작업했고 그것에 대해 생각하고 있는가요? 멋지네요! 어떤 문제를 해결했는지 알려주세요.
|
||||
2. 기능을 *상세히 설명하는* 문단을 작성해주세요;
|
||||
3. 향후 사용을 보여주는 **코드 조각**을 제공해주세요;
|
||||
4. 논문과 관련된 내용인 경우 링크를 첨부해주세요;
|
||||
5. 도움이 될 수 있다고 생각되는 추가 정보(그림, 스크린샷 등)를 첨부해주세요.
|
||||
3. 미래 사용을 보여주는 **코드 조각**을 제공해주세요;
|
||||
4. 이것이 논문과 관련된 경우 링크를 첨부해주세요;
|
||||
5. 도움이 될 수 있는 추가 정보(그림, 스크린샷 등)를 첨부해주세요.
|
||||
|
||||
기능 요청은 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=)에서 열 수 있습니다.
|
||||
|
||||
#### 2.3 피드백 [[23-feedback]]
|
||||
#### 2.3 피드백
|
||||
|
||||
라이브러리 디자인과 그것이 왜 좋은지 또는 나쁜지에 대한 이유에 대한 피드백은 핵심 메인테이너가 사용자 친화적인 라이브러리를 만드는 데 엄청난 도움이 됩니다. 현재 디자인 철학을 이해하려면 [여기](https://huggingface.co/docs/diffusers/conceptual/philosophy)를 참조해 주세요. 특정 디자인 선택이 현재 디자인 철학과 맞지 않는다고 생각되면, 그 이유와 어떻게 변경되어야 하는지 설명해 주세요. 반대로 특정 디자인 선택이 디자인 철학을 너무 따르기 때문에 사용 사례를 제한한다고 생각되면, 그 이유와 어떻게 변경되어야 하는지 설명해 주세요. 특정 디자인 선택이 매우 유용하다고 생각되면, 향후 디자인 결정에 큰 도움이 되므로 이에 대한 의견을 남겨 주세요.
|
||||
라이브러리 디자인과 그것이 왜 좋은지 또는 나쁜지에 대한 이유에 대한 피드백은 핵심 메인테이너가 사용자 친화적인 라이브러리를 만드는 데 엄청난 도움이 됩니다. 현재 디자인 철학을 이해하려면 [여기](https://huggingface.co/docs/diffusers/conceptual/philosophy)를 참조해 주세요. 특정 디자인 선택이 현재 디자인 철학과 맞지 않는다고 생각되면, 그 이유와 어떻게 변경되어야 하는지 설명해 주세요. 반대로 특정 디자인 선택이 디자인 철학을 너무 따르기 때문에 사용 사례를 제한한다고 생각되면, 그 이유와 어떻게 변경되어야 하는지 설명해 주세요. 특정 디자인 선택이 매우 유용하다고 생각되면, 미래의 디자인 결정에 큰 도움이 되므로 이에 대한 의견을 남겨 주세요.
|
||||
|
||||
피드백에 관한 이슈는 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=)에서 열 수 있습니다.
|
||||
|
||||
#### 2.4 기술적인 질문 [[24-technical-questions]]
|
||||
#### 2.4 기술적인 질문
|
||||
|
||||
기술적인 질문은 주로 라이브러리의 특정 코드가 왜 특정 방식으로 작성되었는지 또는 코드의 특정 부분이 무엇을 하는지에 대한 질문입니다. 질문하신 코드 부분에 대한 링크를 제공하고 해당 코드 부분이 이해하기 어려운 이유에 대한 자세한 설명을 해주시기 바랍니다.
|
||||
|
||||
기술적인 질문에 관한 이슈를 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml)에서 열 수 있습니다.
|
||||
|
||||
#### 2.5 새로운 모델, 스케줄러 또는 파이프라인 추가 제안 [[25-proposal-to-add-a-new-model-scheduler-or-pipeline]]
|
||||
#### 2.5 새로운 모델, 스케줄러 또는 파이프라인 추가 제안
|
||||
|
||||
만약 diffusion 모델 커뮤니티에서 Diffusers 라이브러리에 추가하고 싶은 새로운 모델, 파이프라인 또는 스케줄러가 있다면, 다음 정보를 제공해주세요:
|
||||
|
||||
@@ -134,34 +135,34 @@ Diffusers 라이브러리와 관련된 모든 질문이나 의견은 [토론 포
|
||||
* 해당 모델의 오픈 소스 구현에 대한 링크
|
||||
* 모델 가중치가 있는 경우, 가중치의 링크
|
||||
|
||||
직접 모델에 기여하고 싶다면, 가장 잘 안내해드릴 수 있습니다. 또한, 가능하다면 구성 요소(모델, 스케줄러, 파이프라인 등)의 원저자를 GitHub 핸들로 태그하는 것을 잊지 마세요.
|
||||
모델에 직접 기여하고자 하는 경우, 최선의 안내를 위해 우리에게 알려주세요. 또한, 가능하다면 구성 요소(모델, 스케줄러, 파이프라인 등)의 원래 저자를 GitHub 핸들로 태그하는 것을 잊지 마세요.
|
||||
|
||||
모델/파이프라인/스케줄러에 대한 요청을 [여기](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml)에서 열 수 있습니다.
|
||||
|
||||
### 3. GitHub 이슈 탭에서 문제에 대한 답변하기 [[3-answering-issues-on-the-github-issues-tab]]
|
||||
### 3. GitHub 이슈 탭에서 문제에 대한 답변하기
|
||||
|
||||
GitHub에서 이슈에 대한 답변을 하기 위해서는 Diffusers에 대한 기술적인 지식이 필요할 수 있지만, 정확한 답변이 아니더라도 모두가 시도해기를 권장합니다. 이슈에 대한 고품질 답변을 제공하기 위한 몇 가지 팁:
|
||||
- 가능한 한 간결하고 최소한으로 유지합니다.
|
||||
- 주제에 집중합니다. 이슈에 대한 답변은 해당 이슈에 관련된 내용에만 집중해야 합니다.
|
||||
- 자신의 주장을 증명하거나 장려하는 코드, 논문 또는 기타 출처는 링크를 제공하세요.
|
||||
- 코드, 논문 또는 다른 소스를 제공하여 답변을 증명하거나 지지합니다.
|
||||
- 코드로 답변합니다. 간단한 코드 조각이 이슈에 대한 답변이거나 이슈를 해결하는 방법을 보여준다면, 완전히 재현 가능한 코드 조각을 제공해주세요.
|
||||
|
||||
또한, 많은 이슈들은 단순히 주제와 무관하거나 다른 이슈의 중복이거나 관련이 없는 경우가 많습니다. 이러한 이슈들에 대한 답변을 제공하고, 이슈 작성자에게 더 정확한 정보를 제공하거나, 중복된 이슈에 대한 링크를 제공하거나, [포럼](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) 이나 [Discord](https://discord.gg/G7tWnz98XR)로 리디렉션하는 것은 메인테이너에게 큰 도움이 됩니다.
|
||||
|
||||
이슈가 올바른 버그 보고서이고 소스 코드에서 수정이 필요하다고 확인한 경우, 다음 섹션을 살펴보세요.
|
||||
|
||||
다음 모든 기여에 대해서는 PR을 열여야 합니다. [Pull Request 열기](#how-to-open-a-pr) 섹션에서 자세히 설명되어 있습니다.
|
||||
다음 모든 기여에 대해서는 PR을 열여야 합니다. [PR 열기](#how-to-open-a-pr) 섹션에서 자세히 설명되어 있습니다.
|
||||
|
||||
### 4. "Good first issue" 고치기 [[4-fixing-a-good-first-issue]]
|
||||
### 4. "Good first issue" 고치기
|
||||
|
||||
*Good first issues*는 [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) 라벨로 표시됩니다. 일반적으로, 이슈는 이미 잠재적인 해결책이 어떻게 보이는지 설명하고 있어서 수정하기 쉽습니다.
|
||||
만약 이슈가 아직 닫히지 않았고 이 문제를 해결해보고 싶다면, "이 이슈를 해결해보고 싶습니다."라는 메시지를 남기면 됩니다. 일반적으로 세 가지 시나리오가 있습니다:
|
||||
- a.) 이슈 설명에 이미 수정 사항을 제안하는 경우, 해결책이 이해되고 합리적으로 보인다면, PR 또는 드래프트 PR을 열어서 수정할 수 있습니다.
|
||||
- b.) 이슈 설명에 수정 사항이 제안되어 있지 않은 경우, 제안한 수정 사항이 가능할지 물어볼 수 있고, Diffusers 팀의 누군가가 곧 답변해줄 것입니다. 만약 어떻게 수정할지 좋은 아이디어가 있다면, 직접 PR을 열어도 됩니다.
|
||||
- a.) 이슈 설명이 이미 해결책을 제안합니다. 이 경우, 해결책이 이해되고 합리적으로 보인다면, PR 또는 드래프트 PR을 열어서 수정할 수 있습니다.
|
||||
- b.) 이슈 설명이 해결책을 제안하지 않습니다. 이 경우, 어떤 해결책이 가능할지 물어볼 수 있고, Diffusers 팀의 누군가가 곧 답변해줄 것입니다. 만약 어떻게 수정할지 좋은 아이디어가 있다면, 직접 PR을 열어도 됩니다.
|
||||
- c.) 이미 이 문제를 해결하기 위해 열린 PR이 있지만, 이슈가 아직 닫히지 않았습니다. PR이 더 이상 진행되지 않았다면, 새로운 PR을 열고 이전 PR에 링크를 걸면 됩니다. PR은 종종 원래 기여자가 갑자기 시간을 내지 못해 더 이상 진행하지 못하는 경우에 더 이상 진행되지 않게 됩니다. 이는 오픈 소스에서 자주 발생하는 일이며 매우 정상적인 상황입니다. 이 경우, 커뮤니티는 새로 시도하고 기존 PR의 지식을 활용해주면 매우 기쁠 것입니다. 이미 PR이 있고 활성화되어 있다면, 제안을 해주거나 PR을 검토하거나 PR에 기여할 수 있는지 물어보는 등 작성자를 도와줄 수 있습니다.
|
||||
|
||||
|
||||
### 5. 문서에 기여하기 [[5-contribute-to-the-documentation]]
|
||||
### 5. 문서에 기여하기
|
||||
|
||||
좋은 라이브러리는 항상 좋은 문서를 갖고 있습니다! 공식 문서는 라이브러리를 처음 사용하는 사용자들에게 첫 번째 접점 중 하나이며, 따라서 문서에 기여하는 것은 매우 가치 있는 기여입니다.
|
||||
|
||||
@@ -179,7 +180,7 @@ GitHub에서 이슈에 대한 답변을 하기 위해서는 Diffusers에 대한
|
||||
문서에 대한 변경 사항을 로컬에서 확인하는 방법은 [이 페이지](https://github.com/huggingface/diffusers/tree/main/docs)를 참조해주세요.
|
||||
|
||||
|
||||
### 6. 커뮤니티 파이프라인에 기여하기 [[6-contribute-a-community-pipeline]]
|
||||
### 6. 커뮤니티 파이프라인에 기여하기
|
||||
|
||||
> [!TIP]
|
||||
> 커뮤니티 파이프라인에 대해 자세히 알아보려면 [커뮤니티 파이프라인](../using-diffusers/custom_pipeline_overview#community-pipelines) 가이드를 읽어보세요. 커뮤니티 파이프라인이 왜 필요한지 궁금하다면 GitHub 이슈 [#841](https://github.com/huggingface/diffusers/issues/841)를 확인해보세요 (기본적으로, 우리는 diffusion 모델이 추론에 사용될 수 있는 모든 방법을 유지할 수 없지만 커뮤니티가 이를 구축하는 것을 방해하고 싶지 않습니다).
|
||||
@@ -245,7 +246,7 @@ output = pipeline()
|
||||
<hfoptions id="pipeline type">
|
||||
<hfoption id="GitHub pipeline">
|
||||
|
||||
GitHub 파이프라인을 공유하려면 Diffusers [저장소](https://github.com/huggingface/diffusers)에서 Pull Request를 열고 one_step_unet.py 파일을 [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) 하위 폴더에 추가하세요.
|
||||
GitHub 파이프라인을 공유하려면 Diffusers [저장소](https://github.com/huggingface/diffusers)에서 PR을 열고 one_step_unet.py 파일을 [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) 하위 폴더에 추가하세요.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Hub pipeline">
|
||||
@@ -255,7 +256,7 @@ Hub 파이프라인을 공유하려면, 허브에 모델 저장소를 생성하
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### 7. 훈련 예제에 기여하기 [[7-contribute-to-training-examples]]
|
||||
### 7. 훈련 예제에 기여하기
|
||||
|
||||
Diffusers 예제는 [examples](https://github.com/huggingface/diffusers/tree/main/examples) 폴더에 있는 훈련 스크립트의 모음입니다.
|
||||
|
||||
@@ -267,7 +268,7 @@ Diffusers 예제는 [examples](https://github.com/huggingface/diffusers/tree/mai
|
||||
연구용 훈련 예제는 [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects)에 위치하며, 공식 훈련 예제는 `research_projects` 및 `community` 폴더를 제외한 [examples](https://github.com/huggingface/diffusers/tree/main/examples)의 모든 폴더를 포함합니다.
|
||||
공식 훈련 예제는 Diffusers의 핵심 메인테이너가 유지 관리하며, 연구용 훈련 예제는 커뮤니티가 유지 관리합니다.
|
||||
이는 공식 파이프라인 vs 커뮤니티 파이프라인에 대한 [6. 커뮤니티 파이프라인 기여하기](#6-contribute-a-community-pipeline)에서 제시한 이유와 동일합니다: 핵심 메인테이너가 diffusion 모델의 모든 가능한 훈련 방법을 유지 관리하는 것은 현실적으로 불가능합니다.
|
||||
Diffusers 핵심 메인테이너와 커뮤니티가 특정 훈련 패러다임을 너무 실험적이거나 충분히 대중적이지 않다고 판단한다면, 해당 훈련 코드는 `research_projects` 폴더에 넣고 작성자에 의해 관리되어야 합니다.
|
||||
Diffusers 핵심 메인테잉너와 커뮤니티가 특정 훈련 패러다임을 너무 실험적이거나 충분히 인기 없는 것으로 간주하는 경우, 해당 훈련 코드는 `research_projects` 폴더에 넣고 작성자가 유지 관리해야 합니다.
|
||||
|
||||
공식 훈련 및 연구 예제는 하나 이상의 훈련 스크립트, requirements.txt 파일 및 README.md 파일을 포함하는 디렉토리로 구성됩니다. 사용자가 훈련 예제를 사용하려면 리포지토리를 복제해야 합니다:
|
||||
|
||||
@@ -297,14 +298,14 @@ Diffusers와 긴밀하게 통합되어 있기 때문에, 기여자들이 [Accele
|
||||
|
||||
만약 공식 훈련 예제에 기여하는 경우, [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py)에 테스트를 추가하는 것도 확인해주세요. 비공식 훈련 예제에는 이 작업이 필요하지 않습니다.
|
||||
|
||||
### 8. "Good second issue" 고치기 [[8-fixing-a-good-second-issue]]
|
||||
### 8. "Good second issue" 고치기
|
||||
|
||||
"Good second issue"는 [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) 라벨로 표시됩니다. Good second issue는 [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)보다 해결하기가 더 복잡합니다.
|
||||
이슈 설명은 일반적으로 이슈를 해결하는 방법에 대해 덜 구체적이며, 관심 있는 기여자는 라이브러리에 대한 꽤 깊은 이해가 필요합니다.
|
||||
Good second issue를 해결하고자 하는 경우, 해당 이슈를 해결하기 위해 PR을 열고 PR을 이슈에 링크하세요. 이미 해당 이슈에 대한 PR이 열려있지만 병합되지 않은 경우, 왜 병합되지 않았는지 이해하기 위해 살펴보고 개선된 PR을 열어보세요.
|
||||
Good second issue는 일반적으로 Good first issue 이슈보다 병합하기가 더 어려우므로, 핵심 메인테이너에게 도움을 요청하는 것이 좋습니다. PR이 거의 완료된 경우, 핵심 메인테이너는 PR에 참여하여 커밋하고 병합을 진행할 수 있습니다.
|
||||
|
||||
### 9. 파이프라인, 모델, 스케줄러 추가하기 [[9-adding-pipelines-models-schedulers]]
|
||||
### 9. 파이프라인, 모델, 스케줄러 추가하기
|
||||
|
||||
파이프라인, 모델, 스케줄러는 Diffusers 라이브러리에서 가장 중요한 부분입니다.
|
||||
이들은 최첨단 diffusion 기술에 쉽게 접근하도록 하며, 따라서 커뮤니티가 강력한 생성형 AI 애플리케이션을 만들 수 있도록 합니다.
|
||||
@@ -322,9 +323,9 @@ PR에 원본 코드베이스/논문 링크를 추가하고, 가능하면 PR에
|
||||
|
||||
PR에서 막힌 경우나 도움이 필요한 경우, 첫 번째 리뷰나 도움을 요청하는 메시지를 남기는 것을 주저하지 마세요.
|
||||
|
||||
#### Copied from mechanism [[copied-from-mechanism]]
|
||||
#### Copied from mechanism
|
||||
|
||||
`# Copied from mechanism` 은 파이프라인, 모델 또는 스케줄러 코드를 추가할 때 이해해야 할 독특하고 중요한 기능입니다. 이것은 Diffusers 코드베이스 전반에서 볼 수 있으며, 이를 사용하는 이유는 코드베이스를 이해하고 유지 관리하기 쉽게 만들기 위해서입니다. `# Copied from mechanism` 으로 표시된 코드는 복사한 코드와 정확히 동일하도록 강제됩니다. 이렇게 하면 `make fix-copies`를 실행할 때마다 여러 파일에 걸쳐 변경 사항을 쉽게 업데이트하고 전파할 수 있습니다.
|
||||
`# Copied from mechanism` 은 파이프라인, 모델 또는 스케줄러 코드를 추가할 때 이해해야 할 독특하고 중요한 기능입니다. Diffusers 코드베이스 전체에서 이를 자주 볼 수 있는데, 이를 사용하는 이유는 코드베이스를 이해하기 쉽고 유지 관리하기 쉽게 유지하기 위함입니다. `# Copied from mechanism` 으로 표시된 코드는 복사한 코드와 정확히 동일하도록 강제됩니다. 이를 통해 `make fix-copies`를 실행할 때 많은 파일에 걸쳐 변경 사항을 쉽게 업데이트하고 전파할 수 있습니다.
|
||||
|
||||
예를 들어, 아래 코드 예제에서 [`~diffusers.pipelines.stable_diffusion.StableDiffusionPipelineOutput`]은 원래 코드이며, `AltDiffusionPipelineOutput`은 `# Copied from mechanism`을 사용하여 복사합니다. 유일한 차이점은 클래스 접두사를 `Stable`에서 `Alt`로 변경한 것입니다.
|
||||
|
||||
@@ -346,7 +347,7 @@ class AltDiffusionPipelineOutput(BaseOutput):
|
||||
|
||||
더 자세히 알고 싶다면 [~Don't~ Repeat Yourself*](https://huggingface.co/blog/transformers-design-philosophy#4-machine-learning-models-are-static) 블로그 포스트의 이 섹션을 읽어보세요.
|
||||
|
||||
## 좋은 이슈 작성 방법 [[how-to-write-a-good-issue]]
|
||||
## 좋은 이슈 작성 방법
|
||||
|
||||
**이슈를 잘 작성할수록 빠르게 해결될 가능성이 높아집니다.**
|
||||
|
||||
@@ -355,16 +356,16 @@ class AltDiffusionPipelineOutput(BaseOutput):
|
||||
3. **재현 가능성**: 재현 가능한 코드 조각이 없으면 해결할 수 없습니다. 버그를 발견한 경우, 유지 관리자는 그 버그를 재현할 수 있어야 합니다. 이슈에 재현 가능한 코드 조각을 포함해야 합니다. 코드 조각은 Python 인터프리터에 복사하여 붙여넣을 수 있는 형태여야 합니다. 코드 조각이 작동해야 합니다. 즉, 누락된 import나 이미지에 대한 링크가 없어야 합니다. 이슈에는 오류 메시지와 정확히 동일한 오류 메시지를 재현하기 위해 수정하지 않고 복사하여 붙여넣을 수 있는 코드 조각이 포함되어야 합니다. 이슈에 사용자의 로컬 모델 가중치나 로컬 데이터를 사용하는 경우, 독자가 액세스할 수 없는 경우 이슈를 해결할 수 없습니다. 데이터나 모델을 공유할 수 없는 경우, 더미 모델이나 더미 데이터를 만들어 사용해보세요.
|
||||
4. **간결성**: 가능한 한 간결하게 유지하여 독자가 문제를 빠르게 이해할 수 있도록 도와주세요. 문제와 관련이 없는 코드나 정보는 모두 제거해주세요. 버그를 발견한 경우, 문제를 설명하는 가장 간단한 코드 예제를 만들어보세요. 버그를 발견한 후에는 작업 흐름 전체를 문제에 던지는 것이 아니라, 에러가 발생하는 훈련 코드의 어느 부분이 문제인지 먼저 이해하고 몇 줄로 재현해보세요. 전체 데이터셋 대신 더미 데이터를 사용해보세요.
|
||||
5. 링크 추가하기. 특정한 이름, 메서드, 또는 모델을 참조하는 경우, 독자가 더 잘 이해할 수 있도록 링크를 제공해주세요. 특정 PR이나 이슈를 참조하는 경우, 해당 이슈에 링크를 걸어주세요. 독자가 무엇을 말하는지 알고 있다고 가정하지 마세요. 이슈에 링크를 추가할수록 좋습니다.
|
||||
6. 포맷팅. 코드를 파이썬 코드 구문으로, 에러 메시지를 일반 코드 구문으로 형식화하여 이슈를 깔끔하게 작성하세요. 자세한 내용은 [GitHub 공식 포맷팅 문서](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax)를 참조하세요.
|
||||
7. 여러분의 이슈를 단순히 해결해야 할 티켓으로 생각하지 말고, 잘 작성된 백과사전 항목으로 생각해보세요. 추가된 모든 이슈는 공개적으로 이용 가능한 지식에 대한 기여입니다. 잘 작성된 이슈를 추가함으로써 메인테이너가 여러분의 이슈를 더 쉽게 해결할 수 있게 할 뿐만 아니라, 전체 커뮤니티가 라이브러리의 특정 측면을 더 잘 이해할 수 있도록 도움을 주게 됩니다.
|
||||
6. 포맷팅. 파이썬 코드 구문으로 코드를 포맷팅하고, 일반 코드 구문으로 에러 메시지를 포맷팅해주세요. 자세한 내용은 [공식 GitHub 포맷팅 문서](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax)를 참조하세요.
|
||||
7. 이슈를 해결해야 하는 티켓이 아니라, 잘 작성된 백과사전 항목으로 생각해보세요. 추가된 이슈는 공개적으로 사용 가능한 지식에 기여하는 것입니다. 잘 작성된 이슈를 추가함으로써 메인테이너가 문제를 해결하는 데 도움을 주는 것뿐만 아니라, 전체 커뮤니티가 라이브러리의 특정 측면을 더 잘 이해할 수 있도록 도움을 주는 것입니다.
|
||||
|
||||
## 좋은 PR 작성 방법 [[how-to-write-a-good-pr]]
|
||||
## 좋은 PR 작성 방법
|
||||
|
||||
1. 카멜레온이 되세요. 기존의 디자인 패턴과 구문을 이해하고, 여러분이 추가하는 코드가 기존 코드베이스와 자연스럽게 어우러지도록 해야 합니다. 기존 디자인 패턴이나 사용자 인터페이스와 크게 다른 Pull Request들은 병합되지 않습니다.
|
||||
2. 레이저처럼 집중하세요. Pull Request는 하나의 문제, 오직 하나의 문제만 해결해야 합니다. "이왕 추가하는 김에 다른 문제도 고치자"는 함정에 빠지지 않도록 주의하세요. 여러 개의 관련 없는 문제를 해결하는 한 번에 해결하는 Pull Request들은 검토하기가 훨씬 더 어렵습니다.
|
||||
1. 카멜레온이 되세요. 기존의 디자인 패턴과 구문을 이해하고, 코드 추가가 기존 코드베이스에 매끄럽게 흐르도록 해야 합니다. 기존 디자인 패턴이나 사용자 인터페이스와 크게 다른 PR은 병합되지 않습니다.
|
||||
2. 초점을 맞추세요. 하나의 문제만 해결하는 PR을 작성해야 합니다. "추가하면서 다른 문제도 해결하기"에 빠지지 않도록 주의하세요. 여러 개의 관련 없는 문제를 해결하는 PR을 작성하는 것은 리뷰하기가 훨씬 어렵습니다.
|
||||
3. 도움이 되는 경우, 추가한 내용이 어떻게 사용되는지 예제 코드 조각을 추가해보세요.
|
||||
4. Pull Request의 제목은 기여 내용을 요약해야 합니다.
|
||||
5. Pull Request가 이슈를 해결하는 경우, Pull Request의 설명에 이슈 번호를 언급하여 연결되도록 해주세요 (이슈를 참조하는 사람들이 작업 중임을 알 수 있도록).
|
||||
4. PR의 제목은 기여 내용을 요약해야 합니다.
|
||||
5. PR이 이슈를 해결하는 경우, PR 설명에 이슈 번호를 언급하여 연결되도록 해주세요 (이슈를 참조하는 사람들이 작업 중임을 알 수 있도록).
|
||||
6. 진행 중인 작업을 나타내려면 제목에 `[WIP]`를 접두사로 붙여주세요. 이는 중복 작업을 피하고, 병합 준비가 된 PR과 구분할 수 있도록 도움이 됩니다.
|
||||
7. [좋은 이슈를 작성하는 방법](#how-to-write-a-good-issue)에 설명된 대로 텍스트를 구성하고 형식을 지정해보세요.
|
||||
8. 기존 테스트가 통과하는지 확인하세요
|
||||
@@ -373,10 +374,10 @@ class AltDiffusionPipelineOutput(BaseOutput):
|
||||
`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
|
||||
CircleCI는 느린 테스트를 실행하지 않지만, GitHub Actions는 매일 실행합니다!
|
||||
10. 모든 공개 메서드는 마크다운과 잘 작동하는 정보성 docstring을 가져야 합니다. 예시로 [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py)를 참조하세요.
|
||||
11. 리포지토리가 빠르게 성장하고 있기 때문에, 리포지토리에 큰 부담을 주는 파일이 추가되지 않도록 주의해야 합니다. 이미지, 비디오 및 기타 텍스트가 아닌 파일을 포함합니다. 이러한 파일을 배치하기 위해 hf.co 호스팅 `dataset`인 [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) 또는 [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images)를 활용하는 것이 우선입니다.
|
||||
11. 레포지토리가 빠르게 성장하고 있기 때문에, 레포지토리에 큰 부담을 주는 파일이 추가되지 않도록 주의해야 합니다. 이미지, 비디오 및 기타 텍스트가 아닌 파일을 포함합니다. 이러한 파일을 배치하기 위해 hf.co 호스팅 `dataset`인 [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) 또는 [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images)를 활용하는 것이 우선입니다.
|
||||
외부 기여인 경우, 이미지를 PR에 추가하고 Hugging Face 구성원에게 이미지를 이 데이터셋으로 이동하도록 요청하세요.
|
||||
|
||||
## PR을 열기 위한 방법 [[how-to-open-a-pr]]
|
||||
## PR을 열기 위한 방법
|
||||
|
||||
코드를 작성하기 전에, 이미 누군가가 같은 작업을 하고 있는지 확인하기 위해 기존의 PR이나 이슈를 검색하는 것이 좋습니다. 확실하지 않은 경우, 피드백을 받기 위해 이슈를 열어보는 것이 항상 좋은 아이디어입니다.
|
||||
|
||||
@@ -402,7 +403,7 @@ CircleCI는 느린 테스트를 실행하지 않지만, GitHub Actions는 매일
|
||||
|
||||
`main` 브랜치 위에서 **절대** 작업하지 마세요.
|
||||
|
||||
4. 가상 환경에서 다음 명령을 실행하여 개발 환경을 설정하세요:
|
||||
1. 가상 환경에서 다음 명령을 실행하여 개발 환경을 설정하세요:
|
||||
|
||||
```bash
|
||||
$ pip install -e ".[dev]"
|
||||
@@ -466,7 +467,7 @@ CircleCI는 느린 테스트를 실행하지 않지만, GitHub Actions는 매일
|
||||
|
||||
7. 메인테이너가 변경 사항을 요청하는 것은 괜찮습니다. 핵심 기여자들에게도 일어나는 일입니다! 따라서 변경 사항을 Pull request에서 볼 수 있도록 로컬 브랜치에서 작업하고 변경 사항을 포크에 푸시하면 자동으로 Pull request에 나타납니다.
|
||||
|
||||
### 테스트 [[tests]]
|
||||
### 테스트
|
||||
|
||||
라이브러리 동작과 여러 예제를 테스트하기 위해 포괄적인 테스트 묶음이 포함되어 있습니다. 라이브러리 테스트는 [tests 폴더](https://github.com/huggingface/diffusers/tree/main/tests)에서 찾을 수 있습니다.
|
||||
|
||||
@@ -493,7 +494,7 @@ $ python -m unittest discover -s tests -t . -v
|
||||
$ python -m unittest discover -s examples -t examples -v
|
||||
```
|
||||
|
||||
### upstream(HuggingFace) main과 forked main 동기화하기 [[syncing-forked-main-with-upstream-huggingface-main]]
|
||||
### upstream(main)과 forked main 동기화하기
|
||||
|
||||
upstream 저장소에 불필요한 참조 노트를 추가하고 관련 개발자에게 알림을 보내는 것을 피하기 위해,
|
||||
forked 저장소의 main 브랜치를 동기화할 때 다음 단계를 따르세요:
|
||||
@@ -506,6 +507,6 @@ $ git commit -m '<your message without GitHub references>'
|
||||
$ git push --set-upstream origin your-branch-for-syncing
|
||||
```
|
||||
|
||||
### 스타일 가이드 [[style-guide]]
|
||||
### 스타일 가이드
|
||||
|
||||
Documentation string에 대해서는, 🧨 Diffusers는 [Google 스타일](https://google.github.io/styleguide/pyguide.html)을 따릅니다.
|
||||
|
||||
@@ -71,7 +71,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -3,17 +3,17 @@
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
|
||||
The `train_dreambooth_flux.py` script shows how to implement the training procedure and adapt it for [FLUX.1 [dev]](https://blackforestlabs.ai/announcing-black-forest-labs/). We also provide a LoRA implementation in the `train_dreambooth_lora_flux.py` script.
|
||||
> [!NOTE]
|
||||
> [!NOTE]
|
||||
> **Memory consumption**
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
|
||||
>
|
||||
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
|
||||
> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
|
||||
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
|
||||
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
|
||||
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
@@ -163,7 +163,7 @@ To do so, just specify `--train_text_encoder` while launching training. Please k
|
||||
|
||||
> [!NOTE]
|
||||
> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
|
||||
By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
|
||||
By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
|
||||
> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
|
||||
|
||||
To perform DreamBooth LoRA with text-encoder training, run:
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -64,7 +64,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1454,7 +1454,7 @@ def main(args):
|
||||
)
|
||||
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
if not args.train_text_encoder and train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ diffusers==0.20.1
|
||||
accelerate==0.23.0
|
||||
transformers==4.38.0
|
||||
peft==0.5.0
|
||||
torch==2.2.0
|
||||
torch==2.0.1
|
||||
torchvision>=0.16
|
||||
ftfy==6.1.1
|
||||
tensorboard==2.14.0
|
||||
Jinja2==3.1.4
|
||||
Jinja2==3.1.3
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
@@ -1084,7 +1084,7 @@ def main(args):
|
||||
|
||||
# Add noise to the model input according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# time ids
|
||||
def compute_time_ids(original_size, crops_coords_top_left):
|
||||
@@ -1101,7 +1101,7 @@ def main(args):
|
||||
|
||||
# Predict the noise residual
|
||||
unet_added_conditions = {"time_ids": add_time_ids}
|
||||
prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype)
|
||||
prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
|
||||
pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
|
||||
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
|
||||
model_pred = unet(
|
||||
|
||||
@@ -109,9 +109,6 @@ import torch
|
||||
model_id = "path-to-your-trained-model"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
repo_id_embeds = "path-to-your-learned-embeds"
|
||||
pipe.load_textual_inversion(repo_id_embeds)
|
||||
|
||||
prompt = "A <cat-toy> backpack"
|
||||
|
||||
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.31.0.dev0")
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -86,9 +86,6 @@ TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"key_layernorm_list": reassign_query_key_layernorm_inplace,
|
||||
"adaln_layer.adaLN_modulations": reassign_adaln_norm_inplace,
|
||||
"embed_tokens": remove_keys_inplace,
|
||||
"freqs_sin": remove_keys_inplace,
|
||||
"freqs_cos": remove_keys_inplace,
|
||||
"position_embedding": remove_keys_inplace,
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
@@ -126,21 +123,11 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
num_layers: int,
|
||||
num_attention_heads: int,
|
||||
use_rotary_positional_embeddings: bool,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
def convert_transformer(ckpt_path: str):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||
transformer = CogVideoXTransformer3DModel(
|
||||
num_layers=num_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
||||
).to(dtype=dtype)
|
||||
transformer = CogVideoXTransformer3DModel()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
@@ -158,9 +145,9 @@ def convert_transformer(
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
|
||||
def convert_vae(ckpt_path: str):
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
|
||||
vae = AutoencoderKLCogVideoX()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -185,26 +172,13 @@ def get_args():
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
||||
parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
|
||||
parser.add_argument("--fp16", action="store_true", default=True, help="Whether to save the model weights in fp16")
|
||||
parser.add_argument(
|
||||
"--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
||||
)
|
||||
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
|
||||
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
|
||||
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
|
||||
parser.add_argument("--num_attention_heads", type=int, default=30, help="Number of attention heads")
|
||||
# For CogVideoX-2B, use_rotary_positional_embeddings is False. For 5B, it is True
|
||||
parser.add_argument(
|
||||
"--use_rotary_positional_embeddings", action="store_true", default=False, help="Whether to use RoPE or not"
|
||||
)
|
||||
# For CogVideoX-2B, scaling_factor is 1.15258426. For 5B, it is 0.7
|
||||
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
|
||||
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
|
||||
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -214,33 +188,18 @@ if __name__ == "__main__":
|
||||
transformer = None
|
||||
vae = None
|
||||
|
||||
if args.fp16 and args.bf16:
|
||||
raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")
|
||||
|
||||
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(
|
||||
args.transformer_ckpt_path,
|
||||
args.num_layers,
|
||||
args.num_attention_heads,
|
||||
args.use_rotary_positional_embeddings,
|
||||
dtype,
|
||||
)
|
||||
transformer = convert_transformer(args.transformer_ckpt_path)
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
|
||||
vae = convert_vae(args.vae_ckpt_path)
|
||||
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
|
||||
# Apparently, the conversion does not work any more without this :shrug:
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = CogVideoXDDIMScheduler.from_config(
|
||||
{
|
||||
"snr_shift_scale": args.snr_shift_scale,
|
||||
"snr_shift_scale": 3.0,
|
||||
"beta_end": 0.012,
|
||||
"beta_schedule": "scaled_linear",
|
||||
"beta_start": 0.00085,
|
||||
@@ -249,7 +208,7 @@ if __name__ == "__main__":
|
||||
"prediction_type": "v_prediction",
|
||||
"rescale_betas_zero_snr": True,
|
||||
"set_alpha_to_one": True,
|
||||
"timestep_spacing": "trailing",
|
||||
"timestep_spacing": "linspace",
|
||||
}
|
||||
)
|
||||
|
||||
@@ -259,10 +218,5 @@ if __name__ == "__main__":
|
||||
|
||||
if args.fp16:
|
||||
pipe = pipe.to(dtype=torch.float16)
|
||||
if args.bf16:
|
||||
pipe = pipe.to(dtype=torch.bfloat16)
|
||||
|
||||
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
|
||||
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
|
||||
# is either fp16/bf16 here).
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, push_to_hub=args.push_to_hub)
|
||||
|
||||
@@ -254,7 +254,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.31.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.30.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.31.0.dev0"
|
||||
__version__ = "0.30.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -88,7 +88,6 @@ else:
|
||||
"ControlNetModel",
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"FluxControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -255,7 +254,6 @@ else:
|
||||
"CLIPImageProjection",
|
||||
"CogVideoXPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
"FluxPipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -310,7 +308,6 @@ else:
|
||||
"StableCascadeCombinedPipeline",
|
||||
"StableCascadeDecoderPipeline",
|
||||
"StableCascadePriorPipeline",
|
||||
"StableDiffusion3ControlNetInpaintingPipeline",
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
@@ -346,7 +343,6 @@ else:
|
||||
"StableDiffusionXLAdapterPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPAGImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetPAGPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
@@ -553,7 +549,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetModel,
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
FluxControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -698,7 +693,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CLIPImageProjection,
|
||||
CogVideoXPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
FluxControlNetPipeline,
|
||||
FluxPipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
@@ -788,7 +782,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
|
||||
@@ -222,11 +222,7 @@ class IPAdapterMixin:
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
clip_image_size = self.image_encoder.config.image_size
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
|
||||
@@ -280,9 +280,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
See `LoRALinearLayer` for more details.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The text encoder model to load the LoRA layers into.
|
||||
prefix (`str`):
|
||||
@@ -755,9 +753,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
See `LoRALinearLayer` for more details.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The text encoder model to load the LoRA layers into.
|
||||
prefix (`str`):
|
||||
@@ -1253,9 +1249,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
See `LoRALinearLayer` for more details.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The text encoder model to load the LoRA layers into.
|
||||
prefix (`str`):
|
||||
@@ -1495,10 +1489,10 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
return_alphas: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -1583,26 +1577,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
# For state dicts like
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
|
||||
keys = list(state_dict.keys())
|
||||
network_alphas = {}
|
||||
for k in keys:
|
||||
if "alpha" in k:
|
||||
alpha_value = state_dict.get(k)
|
||||
if (torch.is_tensor(alpha_value) and torch.is_floating_point(alpha_value)) or isinstance(
|
||||
alpha_value, float
|
||||
):
|
||||
network_alphas[k] = state_dict.pop(k)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
|
||||
)
|
||||
|
||||
if return_alphas:
|
||||
return state_dict, network_alphas
|
||||
else:
|
||||
return state_dict
|
||||
return state_dict
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
@@ -1636,9 +1611,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(
|
||||
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
|
||||
)
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
@@ -1646,7 +1619,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
@@ -1656,7 +1628,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
text_encoder_state_dict,
|
||||
network_alphas=network_alphas,
|
||||
network_alphas=None,
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
@@ -1665,7 +1637,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
|
||||
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
@@ -1674,10 +1647,6 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
transformer (`SD3Transformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
@@ -1709,12 +1678,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
prefix = cls.transformer_name
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
@@ -1771,9 +1735,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
See `LoRALinearLayer` for more details.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The text encoder model to load the LoRA layers into.
|
||||
prefix (`str`):
|
||||
@@ -2006,9 +1968,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
See `LoRALinearLayer` for more details.
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
@@ -2101,9 +2061,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
A standard state dict containing the lora layer parameters. The key should be prefixed with an
|
||||
additional `text_encoder` to distinguish between unet lora layers.
|
||||
network_alphas (`Dict[str, float]`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the
|
||||
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
|
||||
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
|
||||
See `LoRALinearLayer` for more details.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
The text encoder model to load the LoRA layers into.
|
||||
prefix (`str`):
|
||||
|
||||
@@ -23,7 +23,6 @@ from packaging import version
|
||||
from ..utils import deprecate, is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
_is_legacy_scheduler_kwargs,
|
||||
_is_model_weights_in_cached_folder,
|
||||
_legacy_load_clip_tokenizer,
|
||||
_legacy_load_safety_checker,
|
||||
@@ -43,6 +42,7 @@ logger = logging.get_logger(__name__)
|
||||
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
|
||||
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@@ -135,7 +135,7 @@ def load_single_file_sub_model(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
|
||||
elif is_diffusers_scheduler and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_scheduler(
|
||||
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
||||
)
|
||||
|
||||
@@ -79,10 +79,7 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
"animatediff_scribble": "controlnet_cond_embedding.conv_in.weight",
|
||||
"animatediff_rgb": "controlnet_cond_embedding.weight",
|
||||
"flux": [
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
],
|
||||
"flux": "double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -261,7 +258,7 @@ SCHEDULER_DEFAULT_CONFIG = {
|
||||
"timestep_spacing": "leading",
|
||||
}
|
||||
|
||||
LDM_VAE_KEYS = ["first_stage_model.", "vae."]
|
||||
LDM_VAE_KEY = "first_stage_model."
|
||||
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
@@ -270,8 +267,8 @@ LDM_CLIP_PREFIX_TO_REMOVE = [
|
||||
"cond_stage_model.transformer.",
|
||||
"conditioner.embedders.0.transformer.",
|
||||
]
|
||||
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
|
||||
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
||||
SCHEDULER_LEGACY_KWARGS = ["prediction_type", "scheduler_type"]
|
||||
|
||||
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
|
||||
|
||||
@@ -321,10 +318,6 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
|
||||
return weights_exist
|
||||
|
||||
|
||||
def _is_legacy_scheduler_kwargs(kwargs):
|
||||
return any(k in SCHEDULER_LEGACY_KWARGS for k in kwargs.keys())
|
||||
|
||||
|
||||
def load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
force_download=False,
|
||||
@@ -456,8 +449,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
):
|
||||
if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
|
||||
model_type = "inpainting_v2"
|
||||
elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
|
||||
model_type = "xl_inpaint"
|
||||
else:
|
||||
model_type = "inpainting"
|
||||
|
||||
@@ -525,10 +516,8 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
model_type = "animatediff_v3"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
):
|
||||
elif CHECKPOINT_KEY_NAMES["flux"] in checkpoint:
|
||||
if "guidance_in.in_layer.bias" in checkpoint:
|
||||
model_type = "flux-dev"
|
||||
else:
|
||||
model_type = "flux-schnell"
|
||||
@@ -1187,11 +1176,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
# remove the LDM_VAE_KEY prefix from the ldm checkpoint keys so that it is easier to map them to diffusers keys
|
||||
vae_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
vae_key = ""
|
||||
for ldm_vae_key in LDM_VAE_KEYS:
|
||||
if any(k.startswith(ldm_vae_key) for k in keys):
|
||||
vae_key = ldm_vae_key
|
||||
|
||||
vae_key = LDM_VAE_KEY if any(k.startswith(LDM_VAE_KEY) for k in keys) else ""
|
||||
for key in keys:
|
||||
if key.startswith(vae_key):
|
||||
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
||||
@@ -1492,22 +1477,14 @@ def _legacy_load_scheduler(
|
||||
|
||||
if scheduler_type is not None:
|
||||
deprecation_message = (
|
||||
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`\n\n"
|
||||
"Example:\n\n"
|
||||
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
|
||||
"scheduler = DDIMScheduler()\n"
|
||||
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
|
||||
"Please pass an instance of a Scheduler object directly to the `scheduler` argument in `from_single_file`."
|
||||
)
|
||||
deprecate("scheduler_type", "1.0.0", deprecation_message)
|
||||
|
||||
if prediction_type is not None:
|
||||
deprecation_message = (
|
||||
"Please configure an instance of a Scheduler with the appropriate `prediction_type` and "
|
||||
"pass the object directly to the `scheduler` argument in `from_single_file`.\n\n"
|
||||
"Example:\n\n"
|
||||
"from diffusers import StableDiffusionPipeline, DDIMScheduler\n\n"
|
||||
'scheduler = DDIMScheduler(prediction_type="v_prediction")\n'
|
||||
"pipe = StableDiffusionPipeline.from_single_file(<checkpoint path>, scheduler=scheduler)\n"
|
||||
"Please configure an instance of a Scheduler with the appropriate `prediction_type` "
|
||||
"and pass the object directly to the `scheduler` argument in `from_single_file`."
|
||||
)
|
||||
deprecate("prediction_type", "1.0.0", deprecation_message)
|
||||
|
||||
@@ -1904,10 +1881,6 @@ def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
|
||||
def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
||||
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
||||
|
||||
@@ -35,7 +35,6 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_flux"] = ["FluxControlNetModel"]
|
||||
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
|
||||
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
@@ -88,7 +87,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQModel,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnet_flux import FluxControlNetModel
|
||||
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
|
||||
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from .controlnet_sparsectrl import SparseControlNetModel
|
||||
|
||||
@@ -1695,6 +1695,81 @@ class FusedAuraFlowAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
# YiYi to-do: refactor rope related functions/classes
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
|
||||
class FluxSingleAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
# YiYi to-do: update uising apply_rotary_emb
|
||||
# from ..embeddings import apply_rotary_emb
|
||||
# query = apply_rotary_emb(query, image_rotary_emb)
|
||||
# key = apply_rotary_emb(key, image_rotary_emb)
|
||||
query, key = apply_rope(query, key, image_rotary_emb)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
@@ -1710,7 +1785,16 @@ class FluxAttnProcessor2_0:
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(hidden_states)
|
||||
@@ -1729,293 +1813,58 @@ class FluxAttnProcessor2_0:
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||
if encoder_hidden_states is not None:
|
||||
# `context` projections.
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedFluxAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
|
||||
# `sample` projections.
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
||||
# `context` projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
||||
split_size = encoder_qkv.shape[-1] // 3
|
||||
(
|
||||
encoder_hidden_states_query_proj,
|
||||
encoder_hidden_states_key_proj,
|
||||
encoder_hidden_states_value_proj,
|
||||
) = torch.split(encoder_qkv, split_size, dim=-1)
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
||||
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
# attention
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
# YiYi to-do: update uising apply_rotary_emb
|
||||
# from ..embeddings import apply_rotary_emb
|
||||
# query = apply_rotary_emb(query, image_rotary_emb)
|
||||
# key = apply_rotary_emb(key, image_rotary_emb)
|
||||
query, key = apply_rope(query, key, image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
encoder_hidden_states, hidden_states = (
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
|
||||
class FusedCogVideoXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if image_rotary_emb is not None:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
|
||||
if not attn.is_cross_attention:
|
||||
key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@@ -4256,17 +4105,6 @@ class LoRAAttnAddedKVProcessor:
|
||||
pass
|
||||
|
||||
|
||||
class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
|
||||
deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
|
||||
super().__init__()
|
||||
|
||||
|
||||
ADDED_KV_ATTENTION_PROCESSORS = (
|
||||
AttnAddedKVProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
|
||||
@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class CogVideoXSafeConv3d(nn.Conv3d):
|
||||
r"""
|
||||
"""
|
||||
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
|
||||
"""
|
||||
|
||||
@@ -68,12 +68,12 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input tensor.
|
||||
out_channels (`int`): Number of output channels produced by the convolution.
|
||||
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
|
||||
stride (`int`, defaults to `1`): Stride of the convolution.
|
||||
dilation (`int`, defaults to `1`): Dilation rate of the convolution.
|
||||
pad_mode (`str`, defaults to `"constant"`): Padding mode.
|
||||
in_channels (int): Number of channels in the input tensor.
|
||||
out_channels (int): Number of output channels.
|
||||
kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
|
||||
stride (int, optional): Stride of the convolution. Default is 1.
|
||||
dilation (int, optional): Dilation rate of the convolution. Default is 1.
|
||||
pad_mode (str, optional): Padding mode. Default is "constant".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -118,12 +118,19 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
self.conv_cache = None
|
||||
|
||||
def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
dim = self.temporal_dim
|
||||
kernel_size = self.time_kernel_size
|
||||
if kernel_size > 1:
|
||||
cached_inputs = (
|
||||
[self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
|
||||
)
|
||||
inputs = torch.cat(cached_inputs + [inputs], dim=2)
|
||||
if kernel_size == 1:
|
||||
return inputs
|
||||
|
||||
inputs = inputs.transpose(0, dim)
|
||||
|
||||
if self.conv_cache is not None:
|
||||
inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
|
||||
else:
|
||||
inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
|
||||
|
||||
inputs = inputs.transpose(0, dim).contiguous()
|
||||
return inputs
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
@@ -131,17 +138,16 @@ class CogVideoXCausalConv3d(nn.Module):
|
||||
self.conv_cache = None
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = self.fake_context_parallel_forward(inputs)
|
||||
input_parallel = self.fake_context_parallel_forward(inputs)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
# Note: we could move these to the cpu for a lower maximum memory usage but its only a few
|
||||
# hundred megabytes and so let's not do it for now
|
||||
self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
|
||||
self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
|
||||
|
||||
padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
|
||||
inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
|
||||
input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
|
||||
|
||||
output = self.conv(inputs)
|
||||
output_parallel = self.conv(input_parallel)
|
||||
output = output_parallel
|
||||
return output
|
||||
|
||||
|
||||
@@ -157,8 +163,6 @@ class CogVideoXSpatialNorm3D(nn.Module):
|
||||
The number of channels for input to group normalization layer, and output of the spatial norm layer.
|
||||
zq_channels (`int`):
|
||||
The number of channels for the quantized vector as described in the paper.
|
||||
groups (`int`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -193,26 +197,17 @@ class CogVideoXResnetBlock3D(nn.Module):
|
||||
A 3D ResNet block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
non_linearity (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
conv_shortcut (bool, defaults to `False`):
|
||||
Whether or not to use a convolution shortcut.
|
||||
spatial_norm_dim (`int`, *optional*):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (Optional[int], optional):
|
||||
Number of output channels. If None, defaults to `in_channels`. Default is None.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
temb_channels (int, optional): Number of time embedding channels. Default is 512.
|
||||
groups (int, optional): Number of groups for group normalization. Default is 32.
|
||||
eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
|
||||
non_linearity (str, optional): Activation function to use. Default is "swish".
|
||||
conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
|
||||
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -314,28 +309,18 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
A downsampling block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
add_downsample (`bool`, defaults to `True`):
|
||||
Whether or not to use a downsampling layer. If not used, output dimension would be same as input dimension.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to downsample across temporal dimension.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
|
||||
downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
|
||||
compress_time (bool, optional): If True, apply temporal compression. Default is False.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -420,24 +405,15 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
A middle block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
spatial_norm_dim (`int`, *optional*):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
in_channels (int): Number of input channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -504,30 +480,19 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
An upsampling block used in the CogVideoX model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`):
|
||||
Number of input channels.
|
||||
out_channels (`int`, *optional*):
|
||||
Number of output channels. If None, defaults to `in_channels`.
|
||||
temb_channels (`int`, defaults to `512`):
|
||||
Number of time embedding channels.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
Dropout rate.
|
||||
num_layers (`int`, defaults to `1`):
|
||||
Number of resnet layers.
|
||||
resnet_eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
resnet_act_fn (`str`, defaults to `"swish"`):
|
||||
Activation function to use.
|
||||
resnet_groups (`int`, defaults to `32`):
|
||||
Number of groups to separate the channels into for group normalization.
|
||||
spatial_norm_dim (`int`, defaults to `16`):
|
||||
The dimension to use for spatial norm if it is to be used instead of group norm.
|
||||
add_upsample (`bool`, defaults to `True`):
|
||||
Whether or not to use a upsampling layer. If not used, output dimension would be same as input dimension.
|
||||
compress_time (`bool`, defaults to `False`):
|
||||
Whether or not to downsample across temporal dimension.
|
||||
pad_mode (str, defaults to `"first"`):
|
||||
Padding mode.
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
temb_channels (int): Number of time embedding channels.
|
||||
dropout (float, optional): Dropout rate. Default is 0.0.
|
||||
num_layers (int, optional): Number of layers in the block. Default is 1.
|
||||
resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
|
||||
resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
|
||||
resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
|
||||
spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
|
||||
add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
|
||||
upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
|
||||
compress_time (bool, optional): If True, apply temporal compression. Default is False.
|
||||
pad_mode (str, optional): Padding mode. Default is "first".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -622,12 +587,14 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
double_z (`bool`, *optional*, defaults to `True`):
|
||||
Whether to double the number of output channels for the last block.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -756,12 +723,14 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
|
||||
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
||||
The number of output channels for each block.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
layers_per_block (`int`, *optional*, defaults to 2):
|
||||
The number of layers per block.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups for normalization.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`):
|
||||
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
|
||||
norm_type (`str`, *optional*, defaults to `"group"`):
|
||||
The normalization type to use. Can be either `"group"` or `"spatial"`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -902,7 +871,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
Tuple of block output channels.
|
||||
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
||||
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
|
||||
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
@@ -942,8 +911,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_eps: float = 1e-6,
|
||||
norm_num_groups: int = 32,
|
||||
temporal_compression_ratio: float = 4,
|
||||
sample_height: int = 480,
|
||||
sample_width: int = 720,
|
||||
sample_size: int = 256,
|
||||
scaling_factor: float = 1.15258426,
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
@@ -982,105 +950,25 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
|
||||
# recommended because the temporal parts of the VAE, here, are tricky to understand.
|
||||
# If you decode X latent frames together, the number of output frames is:
|
||||
# (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
|
||||
#
|
||||
# Example with num_latent_frames_batch_size = 2:
|
||||
# - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
|
||||
# => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
||||
# => 6 * 8 = 48 frames
|
||||
# - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
|
||||
# => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
|
||||
# ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
|
||||
# => 1 * 9 + 5 * 8 = 49 frames
|
||||
# It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
|
||||
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
|
||||
# number of temporal frames.
|
||||
self.num_latent_frames_batch_size = 2
|
||||
|
||||
# We make the minimum height and width of sample for tiling half that of the generally supported
|
||||
self.tile_sample_min_height = sample_height // 2
|
||||
self.tile_sample_min_width = sample_width // 2
|
||||
self.tile_latent_min_height = int(
|
||||
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
||||
self.tile_sample_min_size = self.config.sample_size
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
|
||||
# These are experimental overlap factors that were chosen based on experimentation and seem to work best for
|
||||
# 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
|
||||
# and so the tiling implementation has only been tested on those specific resolutions.
|
||||
self.tile_overlap_factor_height = 1 / 6
|
||||
self.tile_overlap_factor_width = 1 / 5
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def _clear_fake_context_parallel_cache(self):
|
||||
def clear_fake_context_parallel_cache(self):
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CogVideoXCausalConv3d):
|
||||
logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
|
||||
module._clear_fake_context_parallel_cache()
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_overlap_factor_height: Optional[float] = None,
|
||||
tile_overlap_factor_width: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_overlap_factor_height (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
||||
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
|
||||
value might cause more tiles to be processed leading to slow down of the decoding process.
|
||||
tile_overlap_factor_width (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
|
||||
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
|
||||
value might cause more tiles to be processed leading to slow down of the decoding process.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_latent_min_height = int(
|
||||
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
|
||||
)
|
||||
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
|
||||
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
@@ -1105,34 +993,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
dec = []
|
||||
for i in range(num_frames // frame_batch_size):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (i + 1) + remaining_frames
|
||||
z_intermediate = z[:, :, start_frame:end_frame]
|
||||
if self.post_quant_conv is not None:
|
||||
z_intermediate = self.post_quant_conv(z_intermediate)
|
||||
z_intermediate = self.decoder(z_intermediate)
|
||||
dec.append(z_intermediate)
|
||||
|
||||
self._clear_fake_context_parallel_cache()
|
||||
dec = torch.cat(dec, dim=2)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -1145,111 +1007,13 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
||||
y / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
||||
x / blend_extent
|
||||
)
|
||||
return b
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
# Rough memory assessment:
|
||||
# - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
|
||||
# - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
|
||||
# - Assume fp16 (2 bytes per value).
|
||||
# Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
|
||||
#
|
||||
# Memory assessment when using tiling:
|
||||
# - Assume everything as above but now HxW is 240x360 by tiling in half
|
||||
# Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = z.shape
|
||||
|
||||
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
|
||||
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
|
||||
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
|
||||
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
|
||||
row_limit_height = self.tile_sample_min_height - blend_extent_height
|
||||
row_limit_width = self.tile_sample_min_width - blend_extent_width
|
||||
frame_batch_size = self.num_latent_frames_batch_size
|
||||
|
||||
# Split z into overlapping tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, height, overlap_height):
|
||||
row = []
|
||||
for j in range(0, width, overlap_width):
|
||||
time = []
|
||||
for k in range(num_frames // frame_batch_size):
|
||||
remaining_frames = num_frames % frame_batch_size
|
||||
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
|
||||
end_frame = frame_batch_size * (k + 1) + remaining_frames
|
||||
tile = z[
|
||||
:,
|
||||
:,
|
||||
start_frame:end_frame,
|
||||
i : i + self.tile_latent_min_height,
|
||||
j : j + self.tile_latent_min_width,
|
||||
]
|
||||
if self.post_quant_conv is not None:
|
||||
tile = self.post_quant_conv(tile)
|
||||
tile = self.decoder(tile)
|
||||
time.append(tile)
|
||||
self._clear_fake_context_parallel_cache()
|
||||
row.append(torch.cat(time, dim=2))
|
||||
rows.append(row)
|
||||
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
|
||||
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
|
||||
result_rows.append(torch.cat(result_row, dim=4))
|
||||
|
||||
dec = torch.cat(result_rows, dim=3)
|
||||
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -1,386 +0,0 @@
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import PeftAdapterMixin
|
||||
from ..models.attention_processor import AttentionProcessor
|
||||
from ..models.modeling_utils import ModelMixin
|
||||
from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .controlnet import BaseOutput, zero_module
|
||||
from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from .modeling_outputs import Transformer2DModelOutput
|
||||
from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class FluxControlNetOutput(BaseOutput):
|
||||
controlnet_block_samples: Tuple[torch.Tensor]
|
||||
controlnet_single_block_samples: Tuple[torch.Tensor]
|
||||
|
||||
|
||||
class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
||||
text_time_guidance_cls = (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
||||
)
|
||||
self.time_text_embed = text_time_guidance_cls(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
)
|
||||
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
|
||||
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
FluxSingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
)
|
||||
for i in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.transformer_blocks)):
|
||||
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.controlnet_single_blocks = nn.ModuleList([])
|
||||
for _ in range(len(self.single_transformer_blocks)):
|
||||
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
|
||||
|
||||
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self):
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls,
|
||||
transformer,
|
||||
num_layers=4,
|
||||
num_single_layers=10,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
load_weights_from_transformer=True,
|
||||
):
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers
|
||||
config["num_single_layers"] = num_single_layers
|
||||
config["attention_head_dim"] = attention_head_dim
|
||||
config["num_attention_heads"] = num_attention_heads
|
||||
|
||||
controlnet = cls(**config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
|
||||
controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
|
||||
controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
|
||||
controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
|
||||
controlnet.single_transformer_blocks.load_state_dict(
|
||||
transformer.single_transformer_blocks.state_dict(), strict=False
|
||||
)
|
||||
|
||||
controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
|
||||
|
||||
return controlnet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.Tensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
pooled_projections: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
img_ids: torch.Tensor = None,
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
# add
|
||||
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
else:
|
||||
guidance = None
|
||||
temb = (
|
||||
self.time_text_embed(timestep, pooled_projections)
|
||||
if guidance is None
|
||||
else self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
block_samples = ()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
|
||||
block_sample = controlnet_block(block_sample)
|
||||
controlnet_block_samples = controlnet_block_samples + (block_sample,)
|
||||
|
||||
controlnet_single_block_samples = ()
|
||||
for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
|
||||
single_block_sample = controlnet_block(single_block_sample)
|
||||
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
|
||||
|
||||
# scaling
|
||||
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
|
||||
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
|
||||
|
||||
#
|
||||
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
||||
controlnet_single_block_samples = (
|
||||
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_samples, controlnet_single_block_samples)
|
||||
|
||||
return FluxControlNetOutput(
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
)
|
||||
@@ -55,7 +55,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
default_out_channels = in_channels
|
||||
@@ -99,7 +98,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels + extra_conditioning_channels,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
|
||||
@@ -374,90 +374,6 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
return embeds
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
crops_coords (`Tuple[int]`):
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
The grid size of the spatial positional embedding (height, width).
|
||||
temporal_size (`int`):
|
||||
The size of the temporal dimension.
|
||||
theta (`float`):
|
||||
Scaling factor for frequency computation.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
||||
"""
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
dim_h = embed_dim // 8 * 3
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
"""
|
||||
RoPE for image tokens with 2d structure.
|
||||
@@ -530,7 +446,6 @@ def get_1d_rotary_pos_embed(
|
||||
linear_factor=1.0,
|
||||
ntk_factor=1.0,
|
||||
repeat_interleave_real=True,
|
||||
freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
|
||||
):
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
|
||||
@@ -553,8 +468,6 @@ def get_1d_rotary_pos_embed(
|
||||
repeat_interleave_real (`bool`, *optional*, defaults to `True`):
|
||||
If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
|
||||
Otherwise, they are concateanted with themselves.
|
||||
freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`):
|
||||
the dtype of the frequency tensor.
|
||||
Returns:
|
||||
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
|
||||
"""
|
||||
@@ -563,19 +476,19 @@ def get_1d_rotary_pos_embed(
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
theta = theta * ntk_factor
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2]
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
||||
if use_real and repeat_interleave_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
elif use_real:
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
|
||||
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D]
|
||||
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
@@ -627,31 +540,6 @@ def apply_rotary_emb(
|
||||
return x_out.type_as(x)
|
||||
|
||||
|
||||
class FluxPosEmbed(nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.squeeze().float().cpu().numpy()
|
||||
is_mps = ids.device.type == "mps"
|
||||
freqs_dtype = torch.float32 if is_mps else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -263,6 +263,41 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def enable_dynamic_upcasting(self, upcast_dtype=None):
|
||||
upcast_dtype = upcast_dtype or torch.float32
|
||||
downcast_dtype = self.dtype
|
||||
|
||||
def upcast_hook_fn(module):
|
||||
module = module.to(upcast_dtype)
|
||||
|
||||
def downcast_hook_fn(module):
|
||||
module = module.to(downcast_dtype)
|
||||
|
||||
def fn_recursive_upcast(module):
|
||||
has_children = list(module.children())
|
||||
if not has_children:
|
||||
module.register_forward_pre_hook(upcast_hook_fn)
|
||||
module.register_forward_hook(downcast_hook_fn)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_upcast(child)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_upcast(module)
|
||||
|
||||
def disable_dynamic_upcasting(self):
|
||||
def fn_recursive_upcast(module):
|
||||
has_children = list(module.children())
|
||||
if not has_children:
|
||||
module._forward_pre_hooks = OrderedDict()
|
||||
module._forward_hooks = OrderedDict()
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_upcast(child)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_upcast(module)
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
|
||||
@@ -68,21 +68,6 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
self.height, self.width = height // patch_size, width // patch_size
|
||||
self.base_size = height // patch_size
|
||||
|
||||
def pe_selection_index_based_on_dim(self, h, w):
|
||||
# select subset of positional embedding based on H, W, where H, W is size of latent
|
||||
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
|
||||
# because original input are in flattened format, we have to flatten this 2d grid as well.
|
||||
h_p, w_p = h // self.patch_size, w // self.patch_size
|
||||
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
|
||||
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
|
||||
original_pe_indexes = original_pe_indexes.view(h_max, w_max)
|
||||
starth = h_max // 2 - h_p // 2
|
||||
endh = starth + h_p
|
||||
startw = w_max // 2 - w_p // 2
|
||||
endw = startw + w_p
|
||||
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
|
||||
return original_pe_indexes.flatten()
|
||||
|
||||
def forward(self, latent):
|
||||
batch_size, num_channels, height, width = latent.size()
|
||||
latent = latent.view(
|
||||
@@ -95,8 +80,7 @@ class AuraFlowPatchEmbed(nn.Module):
|
||||
)
|
||||
latent = latent.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
|
||||
latent = self.proj(latent)
|
||||
pe_index = self.pe_selection_index_based_on_dim(height, width)
|
||||
return latent + self.pos_embed[:, pe_index]
|
||||
return latent + self.pos_embed
|
||||
|
||||
|
||||
# Taken from the original Aura flow inference code.
|
||||
@@ -274,7 +258,6 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
|
||||
"""
|
||||
|
||||
_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -22,7 +22,6 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, FeedForward
|
||||
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
||||
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -38,20 +37,13 @@ class CogVideoXBlock(nn.Module):
|
||||
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
|
||||
|
||||
Parameters:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
time_embed_dim (`int`):
|
||||
The number of channels in timestep embedding.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to be used in feed-forward.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether or not to use bias in attention projection layers.
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Whether or not to use normalization after query and key projections in Attention.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
@@ -98,7 +90,6 @@ class CogVideoXBlock(nn.Module):
|
||||
eps=1e-6,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
processor=CogVideoXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Feed Forward
|
||||
@@ -118,24 +109,24 @@ class CogVideoXBlock(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
|
||||
hidden_states, encoder_hidden_states, temb
|
||||
)
|
||||
|
||||
# attention
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
text_length = norm_encoder_hidden_states.size(1)
|
||||
|
||||
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
|
||||
# them in cross-attention individually
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
attn_output = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
encoder_hidden_states=None,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden_states
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
|
||||
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
|
||||
|
||||
# norm & modulate
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
|
||||
@@ -146,9 +137,8 @@ class CogVideoXBlock(nn.Module):
|
||||
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@@ -157,53 +147,36 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, defaults to `30`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `16`):
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
attention_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in the attention projection layers.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
The height of the input latents.
|
||||
sample_frames (`int`, defaults to `49`):
|
||||
The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
|
||||
instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
|
||||
but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
|
||||
K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
|
||||
patch_size (`int`, defaults to `2`):
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
patch_size (`int`, *optional*):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
temporal_compression_ratio (`int`, defaults to `4`):
|
||||
The compression ratio across the temporal dimension. See documentation for `sample_frames`.
|
||||
max_text_seq_length (`int`, defaults to `226`):
|
||||
The maximum sequence length of the input text embeddings.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states. During inference, you can denoise for up to but not more steps than
|
||||
`num_embeds_ada_norm`.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
||||
Scaling factor to apply in 3D positional embeddings across spatial dimensions.
|
||||
temporal_interpolation_scale (`float`, defaults to `1.0`):
|
||||
Scaling factor to apply in 3D positional embeddings across temporal dimensions.
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
|
||||
caption_channels (`int`, *optional*):
|
||||
The number of channels in the caption embeddings.
|
||||
video_length (`int`, *optional*):
|
||||
The number of frames in the video-like data.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -213,7 +186,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 64,
|
||||
in_channels: int = 16,
|
||||
in_channels: Optional[int] = 16,
|
||||
out_channels: Optional[int] = 16,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
@@ -234,7 +207,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
norm_eps: float = 1e-5,
|
||||
spatial_interpolation_scale: float = 1.875,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
use_rotary_positional_embeddings: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
@@ -299,113 +271,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
@@ -424,18 +295,16 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
|
||||
# 3. Position embedding
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
seq_length = height * width * num_frames // (self.config.patch_size**2)
|
||||
|
||||
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
|
||||
hidden_states = hidden_states + pos_embeds
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
|
||||
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
|
||||
|
||||
# 4. Transformer blocks
|
||||
# 5. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@@ -451,7 +320,6 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
emb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
@@ -459,23 +327,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
if not self.config.use_rotary_positional_embeddings:
|
||||
# CogVideoX-2B
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
else:
|
||||
# CogVideoX-5B
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
|
||||
# 5. Final block
|
||||
# 6. Final block
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 6. Unpatchify
|
||||
# 7. Unpatchify
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
|
||||
from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -247,14 +247,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
|
||||
Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
|
||||
"""
|
||||
self.set_attn_processor(AttnProcessor())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,9 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -23,23 +22,52 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
FluxAttnProcessor2_0,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
)
|
||||
from ...models.attention_processor import Attention, FluxAttnProcessor2_0, FluxSingleAttnProcessor2_0
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# YiYi to-do: refactor rope related functions/classes
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0, "The dimension must be even."
|
||||
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
|
||||
batch_size, seq_length = pos.shape
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
cos_out = torch.cos(out)
|
||||
sin_out = torch.sin(out)
|
||||
|
||||
stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
|
||||
out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
|
||||
return out.float()
|
||||
|
||||
|
||||
# YiYi to-do: refactor rope related functions/classes
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(1)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class FluxSingleTransformerBlock(nn.Module):
|
||||
r"""
|
||||
@@ -64,7 +92,7 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
|
||||
|
||||
processor = FluxAttnProcessor2_0()
|
||||
processor = FluxSingleAttnProcessor2_0()
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
@@ -222,7 +250,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -236,14 +263,13 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: Tuple[int] = (16, 56, 56),
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = in_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
|
||||
|
||||
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=axes_dims_rope)
|
||||
text_time_guidance_cls = (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
|
||||
)
|
||||
@@ -281,106 +307,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
self.set_attn_processor(FusedFluxAttnProcessor2_0())
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
@@ -395,8 +321,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_block_samples=None,
|
||||
controlnet_single_block_samples=None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
@@ -453,19 +377,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if txt_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `txt_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
txt_ids = txt_ids[0]
|
||||
if img_ids.ndim == 3:
|
||||
logger.warning(
|
||||
"Passing `img_ids` 3d torch.Tensor is deprecated."
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
@@ -498,12 +410,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_block_samples is not None:
|
||||
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
@@ -534,15 +440,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
@@ -124,7 +124,7 @@ else:
|
||||
"AnimateDiffSparseControlNetPipeline",
|
||||
"AnimateDiffVideoToVideoPipeline",
|
||||
]
|
||||
_import_structure["flux"] = ["FluxPipeline", "FluxControlNetPipeline"]
|
||||
_import_structure["flux"] = ["FluxPipeline"]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
@@ -154,7 +154,6 @@ else:
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
"StableDiffusionXLPAGInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPAGImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetPAGPipeline",
|
||||
"StableDiffusionXLPAGImg2ImgPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
@@ -174,7 +173,6 @@ else:
|
||||
_import_structure["controlnet_sd3"].extend(
|
||||
[
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3ControlNetInpaintingPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["deepfloyd_if"] = [
|
||||
@@ -467,7 +465,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .controlnet_hunyuandit import (
|
||||
HunyuanDiTControlNetPipeline,
|
||||
)
|
||||
from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline
|
||||
from .controlnet_sd3 import (
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
)
|
||||
from .controlnet_xs import (
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
@@ -494,7 +494,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VersatileDiffusionTextToImagePipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
from .flux import FluxControlNetPipeline, FluxPipeline
|
||||
from .flux import FluxPipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .i2vgen_xl import I2VGenXLPipeline
|
||||
from .kandinsky import (
|
||||
@@ -548,7 +548,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
StableDiffusionXLPAGImg2ImgPipeline,
|
||||
StableDiffusionXLPAGInpaintPipeline,
|
||||
|
||||
@@ -49,14 +49,12 @@ from .kandinsky2_2 import (
|
||||
)
|
||||
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .lumina import LuminaText2ImgPipeline
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
StableDiffusionXLPAGImg2ImgPipeline,
|
||||
StableDiffusionXLPAGInpaintPipeline,
|
||||
@@ -108,7 +106,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
|
||||
("auraflow", AuraFlowPipeline),
|
||||
("flux", FluxPipeline),
|
||||
("lumina", LuminaText2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -124,7 +121,6 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
|
||||
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
|
||||
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
|
||||
("lcm", LatentConsistencyModelImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
@@ -165,12 +161,12 @@ _AUTO_INPAINT_DECODER_PIPELINES_MAPPING = OrderedDict(
|
||||
)
|
||||
|
||||
if is_sentencepiece_available():
|
||||
from .kolors import KolorsImg2ImgPipeline, KolorsPipeline
|
||||
from .kolors import KolorsPipeline
|
||||
from .pag import KolorsPAGPipeline
|
||||
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kolors-pag"] = KolorsPAGPipeline
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsImg2ImgPipeline
|
||||
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kolors"] = KolorsPipeline
|
||||
|
||||
SUPPORTED_TASKS_MAPPINGS = [
|
||||
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
|
||||
@@ -957,8 +953,7 @@ class AutoPipelineForInpainting(ConfigMixin):
|
||||
if "enable_pag" in kwargs:
|
||||
enable_pag = kwargs.pop("enable_pag")
|
||||
if enable_pag:
|
||||
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
|
||||
orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace)
|
||||
orig_class_name = config["_class_name"].replace("Pipeline", "PAGPipeline")
|
||||
|
||||
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
||||
from ...models.embeddings import get_3d_rotary_pos_embed
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
||||
from ...utils import BaseOutput, logging, replace_example_docstring
|
||||
@@ -41,7 +40,6 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers import CogVideoXPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
|
||||
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
|
||||
>>> prompt = (
|
||||
... "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
@@ -57,25 +55,6 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
||||
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
@@ -353,11 +332,20 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
||||
def decode_latents(self, latents: torch.Tensor, num_seconds: int):
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
||||
latents = 1 / self.vae.config.scaling_factor * latents
|
||||
|
||||
frames = self.vae.decode(latents).sample
|
||||
frames = []
|
||||
for i in range(num_seconds):
|
||||
start_frame, end_frame = (0, 3) if i == 0 else (2 * i + 1, 2 * i + 3)
|
||||
|
||||
current_frames = self.vae.decode(latents[:, :, start_frame:end_frame]).sample
|
||||
frames.append(current_frames)
|
||||
|
||||
self.vae.clear_fake_context_parallel_cache()
|
||||
|
||||
frames = torch.cat(frames, dim=2)
|
||||
return frames
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
@@ -430,46 +418,6 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
def fuse_qkv_projections(self) -> None:
|
||||
r"""Enables fused QKV projections."""
|
||||
self.fusing_transformer = True
|
||||
self.transformer.fuse_qkv_projections()
|
||||
|
||||
def unfuse_qkv_projections(self) -> None:
|
||||
r"""Disable QKV projection fusion if enabled."""
|
||||
if not self.fusing_transformer:
|
||||
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
||||
else:
|
||||
self.transformer.unfuse_qkv_projections()
|
||||
self.fusing_transformer = False
|
||||
|
||||
def _prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
||||
|
||||
grid_crops_coords = get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
||||
embed_dim=self.transformer.config.attention_head_dim,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
use_real=True,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -490,7 +438,8 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
num_frames: int = 49,
|
||||
num_frames: int = 48,
|
||||
fps: int = 8,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
guidance_scale: float = 6,
|
||||
@@ -585,10 +534,9 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
if num_frames > 49:
|
||||
raise ValueError(
|
||||
"The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
|
||||
)
|
||||
assert (
|
||||
num_frames <= 48 and num_frames % fps == 0 and fps == 8
|
||||
), f"The number of frames must be divisible by {fps=} and less than 48 frames (for now). Other values are not supported in CogVideoX."
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
@@ -645,6 +593,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
num_frames += 1
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
latent_channels,
|
||||
@@ -660,14 +609,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7. Create rotary embeds if required
|
||||
image_rotary_emb = (
|
||||
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
||||
if self.transformer.config.use_rotary_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
# 8. Denoising loop
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
@@ -688,7 +630,6 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred.float()
|
||||
@@ -732,7 +673,7 @@ class CogVideoXPipeline(DiffusionPipeline):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.decode_latents(latents, num_frames // fps)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
@@ -1538,6 +1538,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
control_model_input,
|
||||
t,
|
||||
|
||||
@@ -23,9 +23,6 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_3_controlnet"] = ["StableDiffusion3ControlNetPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_3_controlnet_inpainting"] = [
|
||||
"StableDiffusion3ControlNetInpaintingPipeline"
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -36,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_3_controlnet import StableDiffusion3ControlNetPipeline
|
||||
from .pipeline_stable_diffusion_3_controlnet_inpainting import StableDiffusion3ControlNetInpaintingPipeline
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_flax_available()):
|
||||
|
||||
-1144
File diff suppressed because it is too large
Load Diff
@@ -23,7 +23,6 @@ except OptionalDependencyNotAvailable:
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_flux"] = ["FluxPipeline"]
|
||||
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -32,7 +31,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_flux import FluxPipeline
|
||||
from .pipeline_flux_controlnet import FluxControlNetPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...loaders import FluxLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
@@ -137,7 +137,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
|
||||
r"""
|
||||
The Flux pipeline for text-to-image generation.
|
||||
|
||||
@@ -331,6 +331,10 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
@@ -360,7 +364,8 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
text_ids = text_ids.repeat(num_images_per_prompt, 1, 1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
@@ -420,8 +425,9 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
@@ -448,35 +454,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
|
||||
return latents
|
||||
|
||||
def enable_vae_slicing(self):
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.vae.enable_slicing()
|
||||
|
||||
def disable_vae_slicing(self):
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_slicing()
|
||||
|
||||
def enable_vae_tiling(self):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
"""
|
||||
self.vae.enable_tiling()
|
||||
|
||||
def disable_vae_tiling(self):
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
||||
computing decoding in one step.
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
@@ -700,13 +677,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
@@ -716,8 +686,16 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = torch.tensor([guidance_scale], device=device)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
|
||||
@@ -1,854 +0,0 @@
|
||||
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.controlnet_flux import FluxControlNetModel
|
||||
from ...models.transformers import FluxTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import FluxPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import load_image
|
||||
>>> from diffusers import FluxControlNetPipeline
|
||||
>>> from diffusers import FluxControlNetModel
|
||||
|
||||
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
|
||||
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
|
||||
>>> pipe = FluxControlNetPipeline.from_pretrained(
|
||||
... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> control_image = load_image("https://huggingface.co/InstantX/SD3-Controlnet-Canny/resolve/main/canny.jpg")
|
||||
>>> prompt = "A girl in city, 25 years old, cool, futuristic"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... control_image=control_image,
|
||||
... controlnet_conditioning_scale=0.6,
|
||||
... num_inference_steps=28,
|
||||
... guidance_scale=3.5,
|
||||
... ).images[0]
|
||||
>>> image.save("flux.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.16,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
The Flux pipeline for text-to-image generation.
|
||||
|
||||
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
||||
|
||||
Args:
|
||||
transformer ([`FluxTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
||||
text_encoder_2 ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`T5TokenizerFast`):
|
||||
Second Tokenizer of class
|
||||
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: T5EncoderModel,
|
||||
tokenizer_2: T5TokenizerFast,
|
||||
transformer: FluxTransformer2DModel,
|
||||
controlnet: FluxControlNetModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
controlnet=controlnet,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 64
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 512,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer_2(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
||||
|
||||
dtype = self.text_encoder_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_length=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
||||
|
||||
# Use pooled output of CLIPTextModel
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
lora_scale (`float`, *optional*):
|
||||
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||
|
||||
# We only use the pooled prompt output from the CLIPTextModel
|
||||
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt_2,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if self.text_encoder_2 is not None:
|
||||
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
||||
|
||||
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
||||
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
|
||||
def _unpack_latents(latents, height, width, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
|
||||
height = height // vae_scale_factor
|
||||
width = width // vae_scale_factor
|
||||
|
||||
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
|
||||
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
|
||||
|
||||
return latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // self.vae_scale_factor)
|
||||
width = 2 * (int(width) // self.vae_scale_factor)
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
|
||||
return latents, latent_image_ids
|
||||
|
||||
# Copied from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
|
||||
def prepare_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
height,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
device,
|
||||
dtype,
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
pass
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
repeat_by = batch_size
|
||||
else:
|
||||
# image batch size is the same as prompt batch size
|
||||
repeat_by = num_images_per_prompt
|
||||
|
||||
image = image.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if do_classifier_free_guidance and not guess_mode:
|
||||
image = torch.cat([image] * 2)
|
||||
|
||||
return image
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
control_image: PipelineImageInput = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
||||
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
||||
images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
dtype = self.transformer.dtype
|
||||
|
||||
lora_scale = (
|
||||
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
||||
)
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
text_ids,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 3. Prepare control image
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
if isinstance(self.controlnet, FluxControlNetModel):
|
||||
control_image = self.prepare_image(
|
||||
image=control_image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
height, width = control_image.shape[-2:]
|
||||
|
||||
# vae encode
|
||||
control_image = self.vae.encode(control_image).latent_dist.sample()
|
||||
control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# pack
|
||||
height_control_image, width_control_image = control_image.shape[2:]
|
||||
control_image = self._pack_latents(
|
||||
control_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height_control_image,
|
||||
width_control_image,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.base_image_seq_len,
|
||||
self.scheduler.config.max_image_seq_len,
|
||||
self.scheduler.config.base_shift,
|
||||
self.scheduler.config.max_shift,
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
timesteps,
|
||||
sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
# handle guidance
|
||||
if self.transformer.config.guidance_embeds:
|
||||
guidance = torch.tensor([guidance_scale], device=device)
|
||||
guidance = guidance.expand(latents.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# controlnet
|
||||
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
|
||||
hidden_states=latents,
|
||||
controlnet_cond=control_image,
|
||||
conditioning_scale=controlnet_conditioning_scale,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
guidance=guidance,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
controlnet_block_samples=controlnet_block_samples,
|
||||
controlnet_single_block_samples=controlnet_single_block_samples,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return FluxPipelineOutput(images=image)
|
||||
@@ -56,7 +56,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> from diffusers.utils import export_to_gif
|
||||
|
||||
>>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too.
|
||||
>>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
|
||||
>>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")
|
||||
>>> # Enable memory optimizations.
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
>>> pipe = LuminaText2ImgPipeline.from_pretrained(
|
||||
... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
... ).cuda()
|
||||
>>> # Enable memory optimizations.
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_pag_controlnet_sd"] = ["StableDiffusionControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPAGPipeline"]
|
||||
_import_structure["pipeline_pag_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetPAGImg2ImgPipeline"]
|
||||
_import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
|
||||
_import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
|
||||
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
|
||||
@@ -45,7 +44,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_pag_controlnet_sd import StableDiffusionControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl import StableDiffusionXLControlNetPAGPipeline
|
||||
from .pipeline_pag_controlnet_sd_xl_img2img import StableDiffusionXLControlNetPAGImg2ImgPipeline
|
||||
from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
|
||||
from .pipeline_pag_kolors import KolorsPAGPipeline
|
||||
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1471,14 +1471,6 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
generator,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
if self.do_perturbed_attention_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
mask, _ = mask.chunk(2)
|
||||
masked_image_latents, _ = masked_image_latents.chunk(2)
|
||||
mask = self._prepare_perturbed_attention_guidance(mask, mask, self.do_classifier_free_guidance)
|
||||
masked_image_latents = self._prepare_perturbed_attention_guidance(
|
||||
masked_image_latents, masked_image_latents, self.do_classifier_free_guidance
|
||||
)
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
@@ -1667,10 +1659,10 @@ class StableDiffusionXLPAGInpaintPipeline(
|
||||
|
||||
if num_channels_unet == 4:
|
||||
init_latents_proper = image_latents
|
||||
if self.do_perturbed_attention_guidance:
|
||||
init_mask, *_ = mask.chunk(3) if self.do_classifier_free_guidance else mask.chunk(2)
|
||||
if self.do_classifier_free_guidance:
|
||||
init_mask, _ = mask.chunk(2)
|
||||
else:
|
||||
init_mask, *_ = mask.chunk(2) if self.do_classifier_free_guidance else mask
|
||||
init_mask = mask
|
||||
|
||||
if i < len(timesteps) - 1:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
|
||||
@@ -22,7 +22,7 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub import ModelCard, model_info
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from packaging import version
|
||||
|
||||
@@ -33,7 +33,6 @@ from ..utils import (
|
||||
ONNX_WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
deprecate,
|
||||
get_class_from_dynamic_module,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
@@ -90,46 +89,49 @@ for library in LOADABLE_CLASSES:
|
||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
|
||||
|
||||
def is_safetensors_compatible(filenames, passed_components=None, folder_names=None) -> bool:
|
||||
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
|
||||
"""
|
||||
Checking for safetensors compatibility:
|
||||
- The model is safetensors compatible only if there is a safetensors file for each model component present in
|
||||
filenames.
|
||||
- By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
|
||||
files to know which safetensors files are needed.
|
||||
- The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.
|
||||
|
||||
Converting default pytorch serialized filenames to safetensors serialized filenames:
|
||||
- For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
|
||||
- For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
|
||||
extension is replaced with ".safetensors"
|
||||
"""
|
||||
pt_filenames = []
|
||||
|
||||
sf_filenames = set()
|
||||
|
||||
passed_components = passed_components or []
|
||||
if folder_names is not None:
|
||||
filenames = {f for f in filenames if os.path.split(f)[0] in folder_names}
|
||||
|
||||
# extract all components of the pipeline and their associated files
|
||||
components = {}
|
||||
for filename in filenames:
|
||||
if not len(filename.split("/")) == 2:
|
||||
_, extension = os.path.splitext(filename)
|
||||
|
||||
if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
|
||||
continue
|
||||
|
||||
component, component_filename = filename.split("/")
|
||||
if component in passed_components:
|
||||
continue
|
||||
if extension == ".bin":
|
||||
pt_filenames.append(os.path.normpath(filename))
|
||||
elif extension == ".safetensors":
|
||||
sf_filenames.add(os.path.normpath(filename))
|
||||
|
||||
components.setdefault(component, [])
|
||||
components[component].append(component_filename)
|
||||
for filename in pt_filenames:
|
||||
# filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
|
||||
path, filename = os.path.split(filename)
|
||||
filename, extension = os.path.splitext(filename)
|
||||
|
||||
# iterate over all files of a component
|
||||
# check if safetensor files exist for that component
|
||||
# if variant is provided check if the variant of the safetensors exists
|
||||
for component, component_filenames in components.items():
|
||||
matches = []
|
||||
for component_filename in component_filenames:
|
||||
filename, extension = os.path.splitext(component_filename)
|
||||
if filename.startswith("pytorch_model"):
|
||||
filename = filename.replace("pytorch_model", "model")
|
||||
else:
|
||||
filename = filename
|
||||
|
||||
match_exists = extension == ".safetensors"
|
||||
matches.append(match_exists)
|
||||
|
||||
if not any(matches):
|
||||
expected_sf_filename = os.path.normpath(os.path.join(path, filename))
|
||||
expected_sf_filename = f"{expected_sf_filename}.safetensors"
|
||||
if expected_sf_filename not in sf_filenames:
|
||||
logger.warning(f"{expected_sf_filename} not found")
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -747,92 +749,3 @@ def _fetch_class_library_tuple(module):
|
||||
class_name = not_compiled_module.__class__.__name__
|
||||
|
||||
return (library, class_name)
|
||||
|
||||
|
||||
def _identify_model_variants(folder: str, variant: str, config: dict) -> dict:
|
||||
model_variants = {}
|
||||
if variant is not None:
|
||||
for sub_folder in os.listdir(folder):
|
||||
folder_path = os.path.join(folder, sub_folder)
|
||||
is_folder = os.path.isdir(folder_path) and sub_folder in config
|
||||
variant_exists = is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
|
||||
if variant_exists:
|
||||
model_variants[sub_folder] = variant
|
||||
return model_variants
|
||||
|
||||
|
||||
def _resolve_custom_pipeline_and_cls(folder, config, custom_pipeline):
|
||||
custom_class_name = None
|
||||
if os.path.isfile(os.path.join(folder, f"{custom_pipeline}.py")):
|
||||
custom_pipeline = os.path.join(folder, f"{custom_pipeline}.py")
|
||||
elif isinstance(config["_class_name"], (list, tuple)) and os.path.isfile(
|
||||
os.path.join(folder, f"{config['_class_name'][0]}.py")
|
||||
):
|
||||
custom_pipeline = os.path.join(folder, f"{config['_class_name'][0]}.py")
|
||||
custom_class_name = config["_class_name"][1]
|
||||
|
||||
return custom_pipeline, custom_class_name
|
||||
|
||||
|
||||
def _maybe_raise_warning_for_inpainting(pipeline_class, pretrained_model_name_or_path: str, config: dict):
|
||||
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
||||
version.parse(config["_diffusers_version"]).base_version
|
||||
) <= version.parse("0.5.1"):
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
|
||||
|
||||
pipeline_class = StableDiffusionInpaintPipelineLegacy
|
||||
|
||||
deprecation_message = (
|
||||
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
|
||||
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
|
||||
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
|
||||
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
|
||||
f" checkpoint {pretrained_model_name_or_path} to the format of"
|
||||
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
|
||||
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
|
||||
)
|
||||
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
|
||||
def _update_init_kwargs_with_connected_pipeline(
|
||||
init_kwargs: dict, passed_pipe_kwargs: dict, passed_class_objs: dict, folder: str, **pipeline_loading_kwargs
|
||||
) -> dict:
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
|
||||
modelcard = ModelCard.load(os.path.join(folder, "README.md"))
|
||||
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
|
||||
|
||||
# We don't scheduler argument to match the existing logic:
|
||||
# https://github.com/huggingface/diffusers/blob/867e0c919e1aa7ef8b03c8eb1460f4f875a683ae/src/diffusers/pipelines/pipeline_utils.py#L906C13-L925C14
|
||||
pipeline_loading_kwargs_cp = pipeline_loading_kwargs.copy()
|
||||
if pipeline_loading_kwargs_cp is not None and len(pipeline_loading_kwargs_cp) >= 1:
|
||||
for k in pipeline_loading_kwargs:
|
||||
if "scheduler" in k:
|
||||
_ = pipeline_loading_kwargs_cp.pop(k)
|
||||
|
||||
def get_connected_passed_kwargs(prefix):
|
||||
connected_passed_class_obj = {
|
||||
k.replace(f"{prefix}_", ""): w for k, w in passed_class_objs.items() if k.split("_")[0] == prefix
|
||||
}
|
||||
connected_passed_pipe_kwargs = {
|
||||
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
|
||||
}
|
||||
|
||||
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
||||
return connected_passed_kwargs
|
||||
|
||||
connected_pipes = {
|
||||
prefix: DiffusionPipeline.from_pretrained(
|
||||
repo_id, **pipeline_loading_kwargs_cp, **get_connected_passed_kwargs(prefix)
|
||||
)
|
||||
for prefix, repo_id in connected_pipes.items()
|
||||
if repo_id is not None
|
||||
}
|
||||
|
||||
for prefix, connected_pipe in connected_pipes.items():
|
||||
# add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
|
||||
init_kwargs.update(
|
||||
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
|
||||
)
|
||||
|
||||
return init_kwargs
|
||||
|
||||
@@ -19,6 +19,7 @@ import inspect
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
|
||||
@@ -75,11 +76,7 @@ from .pipeline_loading_utils import (
|
||||
_get_custom_pipeline_class,
|
||||
_get_final_device_map,
|
||||
_get_pipeline_class,
|
||||
_identify_model_variants,
|
||||
_maybe_raise_warning_for_inpainting,
|
||||
_resolve_custom_pipeline_and_cls,
|
||||
_unwrap_model,
|
||||
_update_init_kwargs_with_connected_pipeline,
|
||||
is_safetensors_compatible,
|
||||
load_sub_model,
|
||||
maybe_raise_or_warn,
|
||||
@@ -626,9 +623,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
>>> pipeline.scheduler = scheduler
|
||||
```
|
||||
"""
|
||||
# Copy the kwargs to re-use during loading connected pipeline.
|
||||
kwargs_copied = kwargs.copy()
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -729,19 +723,33 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
config_dict.pop("_ignore_files", None)
|
||||
|
||||
# 2. Define which model components should load variants
|
||||
# We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
|
||||
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
|
||||
# with variant being `"fp16"`.
|
||||
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
|
||||
# We retrieve the information by matching whether variant
|
||||
# model checkpoints exist in the subfolders
|
||||
model_variants = {}
|
||||
if variant is not None:
|
||||
for folder in os.listdir(cached_folder):
|
||||
folder_path = os.path.join(cached_folder, folder)
|
||||
is_folder = os.path.isdir(folder_path) and folder in config_dict
|
||||
variant_exists = is_folder and any(
|
||||
p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
|
||||
)
|
||||
if variant_exists:
|
||||
model_variants[folder] = variant
|
||||
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
|
||||
folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
|
||||
)
|
||||
custom_class_name = None
|
||||
if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
|
||||
custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
|
||||
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
|
||||
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
||||
):
|
||||
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
||||
custom_class_name = config_dict["_class_name"][1]
|
||||
|
||||
pipeline_class = _get_pipeline_class(
|
||||
cls,
|
||||
config=config_dict,
|
||||
config_dict,
|
||||
load_connected_pipeline=load_connected_pipeline,
|
||||
custom_pipeline=custom_pipeline,
|
||||
class_name=custom_class_name,
|
||||
@@ -753,13 +761,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
|
||||
|
||||
# DEPRECATED: To be removed in 1.0.0
|
||||
# we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
|
||||
# when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
|
||||
_maybe_raise_warning_for_inpainting(
|
||||
pipeline_class=pipeline_class,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
config=config_dict,
|
||||
)
|
||||
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
|
||||
version.parse(config_dict["_diffusers_version"]).base_version
|
||||
) <= version.parse("0.5.1"):
|
||||
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
|
||||
|
||||
pipeline_class = StableDiffusionInpaintPipelineLegacy
|
||||
|
||||
deprecation_message = (
|
||||
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
|
||||
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
|
||||
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
|
||||
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
|
||||
f" checkpoint {pretrained_model_name_or_path} to the format of"
|
||||
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
|
||||
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
|
||||
)
|
||||
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
# 4. Define expected modules given pipeline signature
|
||||
# and define non-None initialized modules (=`init_kwargs`)
|
||||
@@ -770,6 +788,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
# define init kwargs and make sure that optional component modules are filtered out
|
||||
@@ -829,7 +848,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# 7. Load each module in the pipeline
|
||||
current_device_map = None
|
||||
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
||||
# 7.1 device_map shenanigans
|
||||
if final_device_map is not None and len(final_device_map) > 0:
|
||||
component_device = final_device_map.get(name, None)
|
||||
if component_device is not None:
|
||||
@@ -837,15 +855,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
else:
|
||||
current_device_map = None
|
||||
|
||||
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
# 7.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
||||
|
||||
# 7.3 Define all importable classes
|
||||
# 7.2 Define all importable classes
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
loaded_sub_model = None
|
||||
|
||||
# 7.4 Use passed sub model or load class_name from library_name
|
||||
# 7.3 Use passed sub model or load class_name from library_name
|
||||
if name in passed_class_obj:
|
||||
# if the model is in a pipeline module, then we load it from the pipeline
|
||||
# check that passed_class_obj has correct parent class
|
||||
@@ -883,17 +901,56 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||
|
||||
# 8. Handle connected pipelines.
|
||||
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
|
||||
init_kwargs = _update_init_kwargs_with_connected_pipeline(
|
||||
init_kwargs=init_kwargs,
|
||||
passed_pipe_kwargs=passed_pipe_kwargs,
|
||||
passed_class_objs=passed_class_obj,
|
||||
folder=cached_folder,
|
||||
**kwargs_copied,
|
||||
)
|
||||
modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
|
||||
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
|
||||
load_kwargs = {
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"local_files_only": local_files_only,
|
||||
"token": token,
|
||||
"revision": revision,
|
||||
"torch_dtype": torch_dtype,
|
||||
"custom_pipeline": custom_pipeline,
|
||||
"custom_revision": custom_revision,
|
||||
"provider": provider,
|
||||
"sess_options": sess_options,
|
||||
"device_map": device_map,
|
||||
"max_memory": max_memory,
|
||||
"offload_folder": offload_folder,
|
||||
"offload_state_dict": offload_state_dict,
|
||||
"low_cpu_mem_usage": low_cpu_mem_usage,
|
||||
"variant": variant,
|
||||
"use_safetensors": use_safetensors,
|
||||
}
|
||||
|
||||
# 9. Potentially add passed objects if expected
|
||||
def get_connected_passed_kwargs(prefix):
|
||||
connected_passed_class_obj = {
|
||||
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
|
||||
}
|
||||
connected_passed_pipe_kwargs = {
|
||||
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
|
||||
}
|
||||
|
||||
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
||||
return connected_passed_kwargs
|
||||
|
||||
connected_pipes = {
|
||||
prefix: DiffusionPipeline.from_pretrained(
|
||||
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
|
||||
)
|
||||
for prefix, repo_id in connected_pipes.items()
|
||||
if repo_id is not None
|
||||
}
|
||||
|
||||
for prefix, connected_pipe in connected_pipes.items():
|
||||
# add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
|
||||
init_kwargs.update(
|
||||
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
|
||||
)
|
||||
|
||||
# 8. Potentially add passed objects if expected
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
@@ -1116,6 +1173,93 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
component.to("cpu")
|
||||
self.hf_device_map = None
|
||||
|
||||
def enable_dynamic_upcasting(
|
||||
self,
|
||||
components: Optional[List[str]] = None,
|
||||
upcast_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Enable module-wise dynamic upcasting. This allows models to be loaded into the GPU in a low memory dtype e.g.
|
||||
torch.float8_e4m3fn, but perform inference using a dtype that is supported on the GPU, by casting the module to
|
||||
the appropriate dtype right before the foward pass. The module is then moved back to the low memory dtype after
|
||||
the foward pass.
|
||||
|
||||
"""
|
||||
if components is None:
|
||||
raise ValueError("Please provide a list of pipeline component names to apply dynamic upcasting")
|
||||
|
||||
def fn_recursive_upcast(module, dtype, original_dtype, keep_in_fp32_modules):
|
||||
has_children = list(module.children())
|
||||
upcast_dtype = dtype
|
||||
downcast_dtype = original_dtype
|
||||
|
||||
def upcast_hook_fn(module, inputs):
|
||||
module = module.to(upcast_dtype)
|
||||
|
||||
def downcast_hook_fn(module, *args, **kwargs):
|
||||
module = module.to(downcast_dtype)
|
||||
|
||||
if not has_children:
|
||||
module.register_forward_pre_hook(upcast_hook_fn)
|
||||
module.register_forward_hook(downcast_hook_fn)
|
||||
|
||||
for name, child in module.named_children():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = upcast_dtype
|
||||
|
||||
fn_recursive_upcast(child, dtype, original_dtype, keep_in_fp32_modules)
|
||||
|
||||
for component in components:
|
||||
if not hasattr(self, component):
|
||||
raise ValueError(f"Pipeline has no component named: {component}")
|
||||
|
||||
component_module = getattr(self, component)
|
||||
if not isinstance(component_module, torch.nn.Module):
|
||||
raise ValueError(
|
||||
f"Pipeline component: {component} is not a torch.nn.Module. Cannot apply dynamic upcasting."
|
||||
)
|
||||
|
||||
use_keep_in_fp32_modules = (
|
||||
hasattr(component_module, "_keep_in_fp32_modules")
|
||||
and (component_module._keep_in_fp32_modules is not None)
|
||||
and (upcast_dtype != torch.float32)
|
||||
)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = component_module._keep_in_fp32_modules
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
original_dtype = component_module.dtype
|
||||
for name, module in component_module.named_children():
|
||||
fn_recursive_upcast(module, upcast_dtype, original_dtype, keep_in_fp32_modules)
|
||||
|
||||
def disable_dynamic_upcasting(
|
||||
self,
|
||||
):
|
||||
def fn_recursive_upcast(module):
|
||||
has_children = list(module.children())
|
||||
if not has_children:
|
||||
module._forward_pre_hooks = OrderedDict()
|
||||
module._forward_hooks = OrderedDict()
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_upcast(child)
|
||||
|
||||
for component in self.components:
|
||||
if not hasattr(self, component):
|
||||
raise ValueError(f"Pipeline has no component named: {component}")
|
||||
|
||||
component_module = getattr(self, component)
|
||||
if not issubclass(component_module, torch.nn.Module):
|
||||
raise ValueError(
|
||||
f"Pipeline component: {component} is not an torch.nn.Module. Cannot apply dynamic upcasting."
|
||||
)
|
||||
|
||||
for module in component_module.children():
|
||||
fn_recursive_upcast(module)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
|
||||
@@ -1361,7 +1505,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_safetensors
|
||||
and not allow_pickle
|
||||
and not is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
model_filenames, variant=variant, passed_components=passed_components
|
||||
)
|
||||
):
|
||||
raise EnvironmentError(
|
||||
@@ -1370,7 +1514,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if from_flax:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
|
||||
elif use_safetensors and is_safetensors_compatible(
|
||||
model_filenames, passed_components=passed_components, folder_names=model_folder_names
|
||||
model_filenames, variant=variant, passed_components=passed_components
|
||||
):
|
||||
ignore_patterns = ["*.bin", "*.msgpack"]
|
||||
|
||||
|
||||
@@ -281,16 +281,6 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
|
||||
s = torch.tensor([0.008])
|
||||
clamp_range = [0, 1]
|
||||
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
|
||||
var = alphas_cumprod[t]
|
||||
var = var.clamp(*clamp_range)
|
||||
s, min_var = s.to(var.device), min_var.to(var.device)
|
||||
ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
||||
return ratio
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -444,30 +434,10 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
||||
)
|
||||
|
||||
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
timesteps = timesteps[:-1]
|
||||
else:
|
||||
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
|
||||
self.scheduler.config.clip_sample = False # disample sample clipping
|
||||
logger.warning(" set `clip_sample` to be False")
|
||||
|
||||
# 6. Run denoising loop
|
||||
if hasattr(self.scheduler, "betas"):
|
||||
alphas = 1.0 - self.scheduler.betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
||||
else:
|
||||
alphas_cumprod = []
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
if len(alphas_cumprod) > 0:
|
||||
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
|
||||
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
|
||||
else:
|
||||
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
|
||||
else:
|
||||
timestep_ratio = t.expand(latents.size(0)).to(dtype)
|
||||
self._num_timesteps = len(timesteps[:-1])
|
||||
for i, t in enumerate(self.progress_bar(timesteps[:-1])):
|
||||
timestep_ratio = t.expand(latents.size(0)).to(dtype)
|
||||
|
||||
# 7. Denoise latents
|
||||
predicted_latents = self.decoder(
|
||||
@@ -484,8 +454,6 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
||||
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
|
||||
|
||||
# 9. Renoise latents to next timestep
|
||||
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
timestep_ratio = t
|
||||
latents = self.scheduler.step(
|
||||
model_output=predicted_latents,
|
||||
timestep=timestep_ratio,
|
||||
|
||||
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
return self._num_timesteps
|
||||
|
||||
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
|
||||
s = torch.tensor([0.008])
|
||||
s = torch.tensor([0.003])
|
||||
clamp_range = [0, 1]
|
||||
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
|
||||
var = alphas_cumprod[t]
|
||||
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
||||
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
||||
timesteps = timesteps[:-1]
|
||||
else:
|
||||
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
|
||||
if self.scheduler.config.clip_sample:
|
||||
self.scheduler.config.clip_sample = False # disample sample clipping
|
||||
logger.warning(" set `clip_sample` to be False")
|
||||
# 6. Run denoising loop
|
||||
|
||||
+81
-241
@@ -33,20 +33,6 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffu
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess
|
||||
def preprocess(image):
|
||||
warnings.warn(
|
||||
@@ -119,54 +105,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic")
|
||||
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
|
||||
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
):
|
||||
def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_prompt):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
@@ -180,100 +119,81 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
negative_prompt (`str` or `List[str]`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
"""
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
if prompt_embeds is None or pooled_prompt_embeds is None:
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_encoder_out = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
text_embeddings = text_encoder_out.hidden_states[-1]
|
||||
text_pooler_out = text_encoder_out.pooler_output
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_encoder_out = self.text_encoder(
|
||||
text_input_ids.to(device),
|
||||
uncond_encoder_out = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
prompt_embeds = text_encoder_out.hidden_states[-1]
|
||||
pooled_prompt_embeds = text_encoder_out.pooler_output
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt_embeds is None or negative_pooled_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""] * batch_size
|
||||
elif type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
uncond_embeddings = uncond_encoder_out.hidden_states[-1]
|
||||
uncond_pooler_out = uncond_encoder_out.pooler_output
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
text_pooler_out = torch.cat([uncond_pooler_out, text_pooler_out])
|
||||
|
||||
uncond_encoder_out = self.text_encoder(
|
||||
uncond_input.input_ids.to(device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = uncond_encoder_out.hidden_states[-1]
|
||||
negative_pooled_prompt_embeds = uncond_encoder_out.pooler_output
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
return text_embeddings, text_pooler_out
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
|
||||
def decode_latents(self, latents):
|
||||
@@ -287,56 +207,12 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
callback_steps,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
):
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
def check_inputs(self, prompt, image, callback_steps):
|
||||
if not isinstance(prompt, str) and not isinstance(prompt, list):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if (
|
||||
not isinstance(image, torch.Tensor)
|
||||
and not isinstance(image, np.ndarray)
|
||||
and not isinstance(image, PIL.Image.Image)
|
||||
and not isinstance(image, list)
|
||||
):
|
||||
@@ -346,14 +222,10 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
|
||||
# verify batch size of prompt and image are same if image is a list or tensor
|
||||
if isinstance(image, (list, torch.Tensor)):
|
||||
if prompt is not None:
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = len(prompt)
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
batch_size = len(prompt)
|
||||
if isinstance(image, list):
|
||||
image_batch_size = len(image)
|
||||
else:
|
||||
@@ -389,17 +261,13 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt: Union[str, List[str]],
|
||||
image: PipelineImageInput = None,
|
||||
num_inference_steps: int = 75,
|
||||
guidance_scale: float = 9.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
||||
@@ -491,22 +359,10 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
"""
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
)
|
||||
self.check_inputs(prompt, image, callback_steps)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None:
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
batch_size = 1 if isinstance(prompt, str) else len(prompt)
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
@@ -517,32 +373,16 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
prompt = [""] * batch_size
|
||||
|
||||
# 3. Encode input prompt
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
text_embeddings, text_pooler_out = self._encode_prompt(
|
||||
prompt, device, do_classifier_free_guidance, negative_prompt
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds])
|
||||
|
||||
# 4. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
image = image.to(dtype=prompt_embeds.dtype, device=device)
|
||||
image = image.to(dtype=text_embeddings.dtype, device=device)
|
||||
if image.shape[1] == 3:
|
||||
# encode image if not in latent-space yet
|
||||
image = retrieve_latents(self.vae.encode(image), generator=generator) * self.vae.config.scaling_factor
|
||||
image = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor
|
||||
|
||||
# 5. set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
@@ -560,17 +400,17 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
inv_noise_level = (noise_level**2 + 1) ** (-0.5)
|
||||
|
||||
image_cond = F.interpolate(image, scale_factor=2, mode="nearest") * inv_noise_level[:, None, None, None]
|
||||
image_cond = image_cond.to(prompt_embeds.dtype)
|
||||
image_cond = image_cond.to(text_embeddings.dtype)
|
||||
|
||||
noise_level_embed = torch.cat(
|
||||
[
|
||||
torch.ones(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
|
||||
torch.zeros(pooled_prompt_embeds.shape[0], 64, dtype=pooled_prompt_embeds.dtype, device=device),
|
||||
torch.ones(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
|
||||
torch.zeros(text_pooler_out.shape[0], 64, dtype=text_pooler_out.dtype, device=device),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
timestep_condition = torch.cat([noise_level_embed, pooled_prompt_embeds], dim=1)
|
||||
timestep_condition = torch.cat([noise_level_embed, text_pooler_out], dim=1)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
height, width = image.shape[2:]
|
||||
@@ -580,7 +420,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
num_channels_latents,
|
||||
height * 2, # 2x upscale
|
||||
width * 2,
|
||||
prompt_embeds.dtype,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
@@ -614,7 +454,7 @@ class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMix
|
||||
noise_pred = self.unet(
|
||||
scaled_model_input,
|
||||
timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states=text_embeddings,
|
||||
timestep_cond=timestep_condition,
|
||||
).sample
|
||||
|
||||
|
||||
+1
-1
@@ -602,9 +602,9 @@ class StableDiffusionKDiffusionPipeline(
|
||||
sigma_min: float = self.k_diffusion_model.sigmas[0].item()
|
||||
sigma_max: float = self.k_diffusion_model.sigmas[-1].item()
|
||||
sigmas = get_sigmas_karras(n=num_inference_steps, sigma_min=sigma_min, sigma_max=sigma_max)
|
||||
sigmas = sigmas.to(device)
|
||||
else:
|
||||
sigmas = self.scheduler.sigmas
|
||||
sigmas = sigmas.to(device)
|
||||
sigmas = sigmas.to(prompt_embeds.dtype)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
|
||||
@@ -182,21 +182,6 @@ class DiTTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FluxControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FluxTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -287,21 +287,6 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxControlNetPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class FluxPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1637,21 +1622,6 @@ class StableDiffusionXLControlNetInpaintPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPAGImg2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPAGPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Callable, List, Optional, Union
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import PIL.Image
|
||||
import PIL.ImageOps
|
||||
@@ -81,22 +80,12 @@ def load_video(
|
||||
)
|
||||
|
||||
if is_url:
|
||||
response = requests.get(video, stream=True)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to download video. Status code: {response.status_code}")
|
||||
|
||||
parsed_url = urlparse(video)
|
||||
file_name = os.path.basename(unquote(parsed_url.path))
|
||||
|
||||
suffix = os.path.splitext(file_name)[1] or ".mp4"
|
||||
video_data = requests.get(video, stream=True).raw
|
||||
suffix = os.path.splitext(video)[1] or ".mp4"
|
||||
video_path = tempfile.NamedTemporaryFile(suffix=suffix, delete=False).name
|
||||
|
||||
was_tempfile_created = True
|
||||
|
||||
video_data = response.iter_content(chunk_size=8192)
|
||||
with open(video_path, "wb") as f:
|
||||
for chunk in video_data:
|
||||
f.write(chunk)
|
||||
f.write(video_data.read())
|
||||
|
||||
video = video_path
|
||||
|
||||
|
||||
@@ -12,26 +12,19 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
|
||||
from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend, torch_device
|
||||
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
@@ -97,51 +90,3 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_inputs.update({"generator": generator})
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_with_alpha_in_state_dict(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertTrue(output_no_lora.shape == self.output_shape)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
# modify the state dict to have alpha values following
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
|
||||
state_dict_with_alpha = safetensors.torch.load_file(
|
||||
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
|
||||
)
|
||||
alpha_dict = {}
|
||||
for k, v in state_dict_with_alpha.items():
|
||||
# only do for `transformer` and for the k projections -- should be enough to test.
|
||||
if "transformer" in k and "to_k" in k and "lora_A" in k:
|
||||
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
|
||||
state_dict_with_alpha.update(alpha_dict)
|
||||
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(state_dict_with_alpha)
|
||||
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
|
||||
|
||||
@@ -32,7 +32,7 @@ from utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
@require_peft_backend
|
||||
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = StableDiffusion3Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler()
|
||||
scheduler_kwargs = {}
|
||||
uses_flow_matching = True
|
||||
transformer_kwargs = {
|
||||
@@ -80,7 +80,8 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
Related PR: https://github.com/huggingface/diffusers/pull/8584
|
||||
"""
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components[0])
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
|
||||
@@ -124,6 +124,71 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_sdxl_0_9_lora_one(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
|
||||
lora_model_id = "hf-internal-testing/sdxl-0.9-daiton-lora"
|
||||
lora_filename = "daiton-xl-lora-test.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
|
||||
images = images[0, -3:, -3:, -1].flatten()
|
||||
expected = np.array([0.3838, 0.3482, 0.3588, 0.3162, 0.319, 0.3369, 0.338, 0.3366, 0.3213])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected, images)
|
||||
assert max_diff < 1e-3
|
||||
pipe.unload_lora_weights()
|
||||
release_memory(pipe)
|
||||
|
||||
def test_sdxl_0_9_lora_two(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
|
||||
lora_model_id = "hf-internal-testing/sdxl-0.9-costumes-lora"
|
||||
lora_filename = "saijo.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
|
||||
images = images[0, -3:, -3:, -1].flatten()
|
||||
expected = np.array([0.3137, 0.3269, 0.3355, 0.255, 0.2577, 0.2563, 0.2679, 0.2758, 0.2626])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected, images)
|
||||
assert max_diff < 1e-3
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
release_memory(pipe)
|
||||
|
||||
def test_sdxl_0_9_lora_three(self):
|
||||
generator = torch.Generator().manual_seed(0)
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")
|
||||
lora_model_id = "hf-internal-testing/sdxl-0.9-kamepan-lora"
|
||||
lora_filename = "kame_sdxl_v2-000020-16rank.safetensors"
|
||||
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
images = pipe(
|
||||
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
|
||||
).images
|
||||
|
||||
images = images[0, -3:, -3:, -1].flatten()
|
||||
expected = np.array([0.4015, 0.3761, 0.3616, 0.3745, 0.3462, 0.3337, 0.3564, 0.3649, 0.3468])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected, images)
|
||||
assert max_diff < 5e-3
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
release_memory(pipe)
|
||||
|
||||
def test_sdxl_1_0_lora(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
|
||||
@@ -976,6 +976,7 @@ class ModelTesterMixin:
|
||||
self.assertTrue(actual_num_shards == expected_num_shards)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto")
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
if "generator" in inputs_dict:
|
||||
|
||||
@@ -26,11 +26,9 @@ from ..test_modeling_common import ModelTesterMixin
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = AuraFlowTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
@@ -73,7 +71,3 @@ class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
pass
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user