Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 795f2c90fd | |||
| 84e2337807 |
@@ -1,7 +1,6 @@
|
||||
name: Nightly and release tests on main/release branch
|
||||
name: Nightly tests on main
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: "0 0 * * *" # every day at midnight
|
||||
|
||||
@@ -13,96 +12,110 @@ env:
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: yes
|
||||
RUN_NIGHTLY: yes
|
||||
PIPELINE_USAGE_CUTOFF: 5000
|
||||
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
name: Setup Torch Pipelines Matrix
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -e .
|
||||
pip install huggingface_hub
|
||||
- name: Fetch Pipeline Matrix
|
||||
id: fetch_pipeline_matrix
|
||||
run: |
|
||||
matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
|
||||
echo $matrix
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
|
||||
run_nightly_tests_for_torch_pipelines:
|
||||
name: Torch Pipelines CUDA Nightly Tests
|
||||
needs: setup_torch_cuda_pipeline_matrix
|
||||
run_nightly_tests:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
|
||||
runs-on: [single-gpu, nvidia-gpu, t4, ci]
|
||||
config:
|
||||
- name: Nightly PyTorch CUDA tests on Ubuntu
|
||||
framework: pytorch
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
report: torch_cuda
|
||||
- name: Nightly Flax TPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: docker-tpu
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
report: flax_tpu
|
||||
- name: Nightly ONNXRuntime CUDA tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
report: onnx_cuda
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}}
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: nvidia-smi
|
||||
|
||||
if: ${{ matrix.config.runner == 'docker-gpu' }}
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate
|
||||
python -m uv pip install pytest-reportlog
|
||||
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Nightly PyTorch CUDA checkpoint (pipelines) tests
|
||||
|
||||
- name: Run nightly PyTorch CUDA tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
--report-log=tests_pipeline_${{ matrix.module }}_cuda.log \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
--report-log=${{ matrix.config.report }}.log \
|
||||
tests/
|
||||
|
||||
- name: Run nightly Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
--report-log=${{ matrix.config.report }}.log \
|
||||
tests/
|
||||
|
||||
- name: Run nightly ONNXRuntime CUDA tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
--report-log=${{ matrix.config.report }}.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
name: ${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
@@ -111,254 +124,9 @@ jobs:
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_nightly_tests_for_other_torch_modules:
|
||||
name: Torch Non-Pipelines CUDA Nightly Tests
|
||||
runs-on: docker-gpu
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
strategy:
|
||||
matrix:
|
||||
module: [models, schedulers, others, examples]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
|
||||
- name: Run nightly PyTorch CUDA tests for non-pipeline modules
|
||||
if: ${{ matrix.module != 'examples'}}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_${{ matrix.module }}_cuda \
|
||||
--report-log=tests_torch_${{ matrix.module }}_cuda.log \
|
||||
tests/${{ matrix.module }}
|
||||
|
||||
- name: Run nightly example tests with Torch
|
||||
if: ${{ matrix.module == 'examples' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v --make-reports=examples_torch_cuda \
|
||||
--report-log=examples_torch_cuda.log \
|
||||
examples/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_torch_${{ matrix.module }}_cuda_stats.txt
|
||||
cat reports/tests_torch_${{ matrix.module }}_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_${{ matrix.module }}_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_lora_nightly_tests:
|
||||
name: Nightly LoRA Tests with PEFT and TORCH
|
||||
runs-on: docker-gpu
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
|
||||
- name: Run nightly LoRA tests with PEFT and Torch
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_lora_cuda \
|
||||
--report-log=tests_torch_lora_cuda.log \
|
||||
tests/lora
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_torch_lora_cuda_stats.txt
|
||||
cat reports/tests_torch_lora_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_lora_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_flax_tpu_tests:
|
||||
name: Nightly Flax TPU Tests
|
||||
runs-on: docker-tpu
|
||||
if: github.event_name == 'schedule'
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
|
||||
- name: Run nightly Flax TPU tests
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
--report-log=tests_flax_tpu.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_nightly_onnx_tests:
|
||||
name: Nightly ONNXRuntime CUDA tests on Ubuntu
|
||||
runs-on: docker-gpu
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
|
||||
- name: Run nightly ONNXRuntime CUDA tests
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
--report-log=tests_onnx_cuda.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: ${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_nightly_tests_apple_m1:
|
||||
name: Nightly PyTorch MPS tests on MacOS
|
||||
runs-on: [ self-hosted, apple-m1 ]
|
||||
if: github.event_name == 'schedule'
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
|
||||
@@ -32,11 +32,9 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
ruff check examples tests src utils scripts
|
||||
ruff format examples tests src utils scripts --check
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
@@ -51,15 +49,11 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
- name: Check quality
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
@@ -71,7 +65,7 @@ jobs:
|
||||
|
||||
name: LoRA - ${{ matrix.lib-versions }}
|
||||
|
||||
runs-on: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runs-on: docker-cpu
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
@@ -108,7 +102,7 @@ jobs:
|
||||
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/lora/
|
||||
|
||||
@@ -40,11 +40,9 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
ruff check examples tests src utils scripts
|
||||
ruff format examples tests src utils scripts --check
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
@@ -59,15 +57,11 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
- name: Check quality
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
@@ -77,22 +71,22 @@ jobs:
|
||||
config:
|
||||
- name: Fast PyTorch Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: [ self-hosted, intel-cpu, 32-cpu, 256-ram, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_pipelines
|
||||
- name: Fast PyTorch Models & Schedulers CPU tests
|
||||
framework: pytorch_models
|
||||
runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_models_schedulers
|
||||
- name: Fast Flax CPU tests
|
||||
framework: flax
|
||||
runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: PyTorch Example CPU tests
|
||||
framework: pytorch_examples
|
||||
runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_example_cpu
|
||||
|
||||
@@ -130,7 +124,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/pipelines
|
||||
@@ -139,7 +133,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch_models' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and not Dependency" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
@@ -148,7 +142,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests
|
||||
@@ -158,7 +152,7 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install peft
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
examples
|
||||
|
||||
@@ -181,7 +175,7 @@ jobs:
|
||||
config:
|
||||
- name: Hub tests for models, schedulers, and pipelines
|
||||
framework: hub_tests_pytorch
|
||||
runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_hub
|
||||
|
||||
|
||||
@@ -29,22 +29,22 @@ jobs:
|
||||
config:
|
||||
- name: Fast PyTorch CPU tests on Ubuntu
|
||||
framework: pytorch
|
||||
runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu
|
||||
- name: Fast Flax CPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
- name: PyTorch Example CPU tests on Ubuntu
|
||||
framework: pytorch_examples
|
||||
runner: [ self-hosted, intel-cpu, 8-cpu, ci ]
|
||||
runner: docker-cpu
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_example_cpu
|
||||
|
||||
@@ -81,7 +81,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
@@ -90,7 +90,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
@@ -109,7 +109,7 @@ jobs:
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install peft
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
examples
|
||||
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
name: Update Diffusers metadata
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- update_diffusers_metadata*
|
||||
|
||||
jobs:
|
||||
update_metadata:
|
||||
runs-on: ubuntu-22.04
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -l {0}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup environment
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install datasets pandas
|
||||
pip install .[torch]
|
||||
|
||||
- name: Update metadata
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
|
||||
run: |
|
||||
python utils/update_metadata.py --commit_sha ${{ github.sha }}
|
||||
@@ -42,7 +42,6 @@ repo-consistency:
|
||||
quality:
|
||||
ruff check $(check_dirs) setup.py
|
||||
ruff format --check $(check_dirs) setup.py
|
||||
doc-builder style src/diffusers docs/source --max_len 119 --check_only
|
||||
python utils/check_doc_toc.py
|
||||
|
||||
# Format source code automatically and check is there are any problems left that need manual fixing
|
||||
@@ -56,7 +55,6 @@ extra_style_checks:
|
||||
style:
|
||||
ruff check $(check_dirs) setup.py --fix
|
||||
ruff format $(check_dirs) setup.py
|
||||
doc-builder style src/diffusers docs/source --max_len 119
|
||||
${MAKE} autogenerate_code
|
||||
${MAKE} extra_style_checks
|
||||
|
||||
|
||||
@@ -20,8 +20,7 @@ The abstract of the paper is the following:
|
||||
|
||||
*Although audio generation shares commonalities across different types of audio, such as speech, music, and sound effects, designing models for each type requires careful consideration of specific objectives and biases that can significantly differ from those of other types. To bring us closer to a unified perspective of audio generation, this paper proposes a framework that utilizes the same learning method for speech, music, and sound effect generation. Our framework introduces a general representation of audio, called "language of audio" (LOA). Any audio can be translated into LOA based on AudioMAE, a self-supervised pre-trained representation learning model. In the generation process, we translate any modalities into LOA by using a GPT-2 model, and we perform self-supervised audio generation learning with a latent diffusion model conditioned on LOA. The proposed framework naturally brings advantages such as in-context learning abilities and reusable self-supervised pretrained AudioMAE and latent diffusion models. Experiments on the major benchmarks of text-to-audio, text-to-music, and text-to-speech demonstrate state-of-the-art or competitive performance against previous approaches. Our code, pretrained model, and demo are available at [this https URL](https://audioldm.github.io/audioldm2).*
|
||||
|
||||
This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi) and [Nguyễn Công Tú Anh](https://github.com/tuanh123789). The original codebase can be
|
||||
found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
|
||||
This pipeline was contributed by [sanchit-gandhi](https://huggingface.co/sanchit-gandhi). The original codebase can be found at [haoheliu/audioldm2](https://github.com/haoheliu/audioldm2).
|
||||
|
||||
## Tips
|
||||
|
||||
@@ -37,8 +36,6 @@ See table below for details on the three checkpoints:
|
||||
| [audioldm2](https://huggingface.co/cvssp/audioldm2) | Text-to-audio | 350M | 1.1B | 1150k |
|
||||
| [audioldm2-large](https://huggingface.co/cvssp/audioldm2-large) | Text-to-audio | 750M | 1.5B | 1150k |
|
||||
| [audioldm2-music](https://huggingface.co/cvssp/audioldm2-music) | Text-to-music | 350M | 1.1B | 665k |
|
||||
| [audioldm2-gigaspeech](https://huggingface.co/anhnct/audioldm2_gigaspeech) | Text-to-speech | 350M | 1.1B |10k |
|
||||
| [audioldm2-ljspeech](https://huggingface.co/anhnct/audioldm2_ljspeech) | Text-to-speech | 350M | 1.1B | |
|
||||
|
||||
### Constructing a prompt
|
||||
|
||||
@@ -56,7 +53,7 @@ See table below for details on the three checkpoints:
|
||||
* The quality of the generated waveforms can vary significantly based on the seed. Try generating with different seeds until you find a satisfactory generation.
|
||||
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly.
|
||||
|
||||
The following example demonstrates how to construct good music and speech generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
|
||||
The following example demonstrates how to construct good music generation using the aforementioned tips: [example](https://huggingface.co/docs/diffusers/main/en/api/pipelines/audioldm2#diffusers.AudioLDM2Pipeline.__call__.example).
|
||||
|
||||
<Tip>
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
|
||||
@@ -54,7 +54,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
@@ -84,7 +84,7 @@ Many of the basic parameters are described in the [DreamBooth](dreambooth#script
|
||||
- `--freeze_model`: freezes the key and value parameters in the cross-attention layer; the default is `crossattn_kv`, but you can set it to `crossattn` to train all the parameters in the cross-attention layer
|
||||
- `--concepts_list`: to learn multiple concepts, provide a path to a JSON file containing the concepts
|
||||
- `--modifier_token`: a special word used to represent the learned concept
|
||||
- `--initializer_token`: a special word used to initialize the embeddings of the `modifier_token`
|
||||
- `--initializer_token`:
|
||||
|
||||
### Prior preservation loss
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
@@ -180,7 +180,7 @@ elif args.pretrained_model_name_or_path:
|
||||
revision=args.revision,
|
||||
use_fast=False,
|
||||
)
|
||||
|
||||
|
||||
# Load scheduler and models
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
text_encoder = text_encoder_cls.from_pretrained(
|
||||
|
||||
@@ -51,7 +51,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
@@ -89,7 +89,7 @@ The dataset preprocessing code and training loop are found in the [`main()`](htt
|
||||
|
||||
As with the script parameters, a walkthrough of the training script is provided in the [Text-to-image](text2image#training-script) training guide. Instead, this guide takes a look at the InstructPix2Pix relevant parts of the script.
|
||||
|
||||
The script begins by modifying the [number of input channels](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L445) in the first convolutional layer of the UNet to account for InstructPix2Pix's additional conditioning image:
|
||||
The script begins by modifing the [number of input channels](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/instruct_pix2pix/train_instruct_pix2pix.py#L445) in the first convolutional layer of the UNet to account for InstructPix2Pix's additional conditioning image:
|
||||
|
||||
```py
|
||||
in_channels = 8
|
||||
|
||||
@@ -59,7 +59,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
@@ -235,7 +235,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_prior.py \
|
||||
--validation_prompts="A robot pokemon, 4k photo" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub \
|
||||
--output_dir="kandi2-prior-pokemon-model"
|
||||
--output_dir="kandi2-prior-pokemon-model"
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
@@ -259,7 +259,7 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_decoder.py \
|
||||
--validation_prompts="A robot pokemon, 4k photo" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub \
|
||||
--output_dir="kandi2-decoder-pokemon-model"
|
||||
--output_dir="kandi2-decoder-pokemon-model"
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
@@ -53,7 +53,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
@@ -252,4 +252,4 @@ The SDXL training script is discussed in more detail in the [SDXL training](sdxl
|
||||
Congratulations on distilling a LCM model! To learn more about LCM, the following may be helpful:
|
||||
|
||||
- Learn how to use [LCMs for inference](../using-diffusers/lcm) for text-to-image, image-to-image, and with LoRA checkpoints.
|
||||
- Read the [SDXL in 4 steps with Latent Consistency LoRAs](https://huggingface.co/blog/lcm_lora) blog post to learn more about SDXL LCM-LoRA's for super fast inference, quality comparisons, benchmarks, and more.
|
||||
- Read the [SDXL in 4 steps with Latent Consistency LoRAs](https://huggingface.co/blog/lcm_lora) blog post to learn more about SDXL LCM-LoRA's for super fast inference, quality comparisons, benchmarks, and more.
|
||||
@@ -59,7 +59,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
|
||||
@@ -53,7 +53,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
|
||||
@@ -69,7 +69,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
|
||||
@@ -67,7 +67,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
|
||||
@@ -51,7 +51,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
|
||||
@@ -53,7 +53,7 @@ accelerate config default
|
||||
|
||||
Or if your environment doesn't support an interactive shell, like a notebook, you can use:
|
||||
|
||||
```py
|
||||
```bash
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
@@ -173,7 +173,7 @@ pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torc
|
||||
|
||||
caption = "A cute bird pokemon holding a shield"
|
||||
images = pipeline(
|
||||
caption,
|
||||
caption,
|
||||
width=1024,
|
||||
height=1536,
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
|
||||
@@ -133,62 +133,6 @@ image
|
||||
|
||||

|
||||
|
||||
### Customize adapters strength
|
||||
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
|
||||
|
||||
For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
|
||||
```python
|
||||
pipe.enable_lora() # enable lora again, after we disabled it above
|
||||
prompt = "toy_face of a hacker with a hoodie, pixel art"
|
||||
adapter_weight_scales = { "unet": { "down": 1, "mid": 0, "up": 0} }
|
||||
pipe.set_adapters("pixel", adapter_weight_scales)
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
Let's see how turning off the `down` part and turning on the `mid` and `up` part respectively changes the image.
|
||||
```python
|
||||
adapter_weight_scales = { "unet": { "down": 0, "mid": 1, "up": 0} }
|
||||
pipe.set_adapters("pixel", adapter_weight_scales)
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
```python
|
||||
adapter_weight_scales = { "unet": { "down": 0, "mid": 0, "up": 1} }
|
||||
pipe.set_adapters("pixel", adapter_weight_scales)
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
Looks cool!
|
||||
|
||||
This is a really powerful feature. You can use it to control the adapter strengths down to per-transformer level. And you can even use it for multiple adapters.
|
||||
```python
|
||||
adapter_weight_scales_toy = 0.5
|
||||
adapter_weight_scales_pixel = {
|
||||
"unet": {
|
||||
"down": 0.9, # all transformers in the down-part will use scale 0.9
|
||||
# "mid" # because, in this example, "mid" is not given, all transformers in the mid part will use the default scale 1.0
|
||||
"up": {
|
||||
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
|
||||
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
|
||||
}
|
||||
}
|
||||
}
|
||||
pipe.set_adapters(["toy", "pixel"], [adapter_weight_scales_toy, adapter_weight_scales_pixel])
|
||||
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
|
||||
image
|
||||
```
|
||||
|
||||

|
||||
|
||||
## Manage active adapters
|
||||
|
||||
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.LoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
|
||||
|
||||
@@ -148,9 +148,9 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
image = pipeline(
|
||||
prompt="A croissant shaped like a cute bear.",
|
||||
negative_prompt="Deformed, ugly, bad anatomy",
|
||||
image = pipe(
|
||||
prompt = "A croissant shaped like a cute bear."
|
||||
negative_prompt = "Deformed, ugly, bad anatomy"
|
||||
callback_on_step_end=decode_tensors,
|
||||
callback_on_step_end_tensor_inputs=["latents"],
|
||||
).images[0]
|
||||
|
||||
@@ -179,210 +179,6 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
|
||||
)
|
||||
```
|
||||
|
||||
### Switch loaded pippelines
|
||||
|
||||
There are many diffuser pipelines that use the same pre-trained model as [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`], but they implement specific features to help you achieve better generation results. This guide will show you how to use the `from_pipe` API to create multiple pipelines without increasing memory usage. By using this approach, you can easily switch between pipelines to use different features.
|
||||
|
||||
Let's take an example where we first create a [`StableDiffusionPipeline`] and then reuse the already loaded model components to create a [`StableDiffusionSAGPipeline`] to enhance generation quality.
|
||||
|
||||
we will generate an image of a bear eating pizza using Stable Diffusion with the IP-Adapter
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline
|
||||
import torch
|
||||
import gc
|
||||
from diffusers.utils import load_image
|
||||
from accelerate.utils import compute_module_sizes
|
||||
|
||||
base_repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
|
||||
num_inference_steps = 50
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
|
||||
prompt="bear eats pizza"
|
||||
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"
|
||||
|
||||
pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
|
||||
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
pipe_sd.set_ip_adapter_scale(0.6)
|
||||
pipe_sd.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
out_sd = pipe_sd(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
ip_adapter_image=image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
let’s take a look at the image and also print out the memory used
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sd_0.png"/>
|
||||
</div>
|
||||
|
||||
```python
|
||||
def bytes_to_giga_bytes(bytes):
|
||||
return bytes / 1024 / 1024 / 1024
|
||||
print(
|
||||
f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
Max memory allocated: 4.406213283538818 GB
|
||||
```
|
||||
|
||||
Now, we can use `from_pipe` to switch to the SAG pipeline.
|
||||
|
||||
```python
|
||||
pipe_sag = StableDiffusionSAGPipeline.from_pipe(
|
||||
pipe_sd,
|
||||
)
|
||||
```
|
||||
|
||||
It already has IP-Adapter loaded so that you can pass the same bear image as `ip_adapter_image`
|
||||
|
||||
```python
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
out_sag = pipe_sag(
|
||||
prompt = prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
ip_adapter_image=image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
guidance_scale=1.0,
|
||||
sag_scale=0.75).images[0]
|
||||
```
|
||||
|
||||
You can see a pretty nice improvement in the output
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sag_1.png"/>
|
||||
</div>
|
||||
|
||||
Now we have both `stableDiffusionPipeline` and `StableDiffusionSAGPipeline` co-existing with the same loaded model components; You can use them interchangeably without additional memory.
|
||||
|
||||
```
|
||||
print(
|
||||
f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
|
||||
)
|
||||
```
|
||||
|
||||
```bash
|
||||
Max memory allocated: 4.406213283538818 GB
|
||||
```
|
||||
|
||||
Let's unload the IP adapter from the SAG pipeline. It's important to note that methods like `load_ip_adapter` and `unload_ip_adapter` modify the state of the model components. Therefore, when you use these methods on one pipeline, it will affect all other pipelines that share the same model components.
|
||||
|
||||
```bash
|
||||
pipe_sag.unload_ip_adapter()
|
||||
```
|
||||
|
||||
If you try to use the Stable Diffusion pipeline with IP adapter again, it will fail
|
||||
|
||||
```bash
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
out_sd = pipe_sd(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
ip_adapter_image=image,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
```bash
|
||||
AttributeError: 'NoneType' object has no attribute 'image_projection_layers'
|
||||
```
|
||||
|
||||
Please note that the pipeline methods may not function properly on a new pipeline created using the `from_pipe` method. For instance, the `enable_model_cpu_offload` method installs hooks to the model components based on a unique offloading sequence for each pipeline. Therefore, if the models are executed in a different order in the new pipeline, the CPU offloading may not work correctly.
|
||||
|
||||
To ensure proper functionality, we recommend re-applying the pipeline methods on the new pipeline created using the `from_pipe` method.
|
||||
|
||||
You can also add or subtract model components when you create new pipelines. Let's now create a AnimateDiff pipeline with an additional `MotionAdapter` module
|
||||
|
||||
```bash
|
||||
from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
|
||||
|
||||
pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
|
||||
pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
|
||||
# load ip_adapter again and load lora weights
|
||||
pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
|
||||
pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
|
||||
pipe_animate.to("cuda")
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
|
||||
out = pipe_animate(
|
||||
prompt= prompt,
|
||||
num_frames=16,
|
||||
num_inference_steps=num_inference_steps,
|
||||
ip_adapter_image = image,
|
||||
generator=generator,
|
||||
).frames[0]
|
||||
export_to_gif(out, "out_animate.gif")
|
||||
```
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_animate_3.gif"/>
|
||||
</div>
|
||||
|
||||
|
||||
When creating multiple pipelines using the `from_pipe` method, it is important to note that the memory requirement will be determined by the pipeline with the highest memory usage. This means that regardless of the number of pipelines you create, the total memory requirement will always be the same as the highest memory requirement among the pipelines.
|
||||
|
||||
For example, we have created three pipelines - `stableDiffusionPipeline`, `StableDiffusionSAGPipeline`, and `AnimateDiffPipeline` - and the `AnimateDiffPipeline` has the highest memory requirement, then the total memory usage will be based on the memory requirement of the `AnimateDiffPipeline`.
|
||||
|
||||
Therefore, creating additional pipelines will not add up to the total memory requirement. Each pipeline can be used interchangeably without any additional memory overhead.
|
||||
|
||||
|
||||
Did you know that you can use `from_pipe` with a community pipeline? Let me show you an example of using long negative prompt and prompt weighting!
|
||||
|
||||
```bash
|
||||
pipe_lpw = DiffusionPipeline.from_pipe(
|
||||
pipe_sd,
|
||||
custom_pipeline="lpw_stable_diffusion",
|
||||
).to("cuda")
|
||||
|
||||
prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
|
||||
neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
out_lpw = pipe_lpw.text2img(
|
||||
prompt,
|
||||
negative_prompt=neg_prompt,
|
||||
width=512,height=512,
|
||||
max_embeddings_multiples=3,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_lpw_4.png"/>
|
||||
</div>
|
||||
|
||||
let’s run StableDiffusionPipeline with the same inputs to compare: the result from the long prompt weighting pipeline is more aligned with the text prompt.
|
||||
|
||||
```
|
||||
generator = torch.Generator(device="cpu").manual_seed(33)
|
||||
out_sd = pipe_sd(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
).images[0]
|
||||
out_sd
|
||||
```
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sd_5.png"/>
|
||||
</div>
|
||||
|
||||
|
||||
You can easily switch between different pipelines using the `from_pipe` method, similar to turning on and off a feature on your pipeline. To switch between tasks, you can use the `from_pipe` method with `AutoPipeline`, which automatically identifies the pipeline class based on the task. You can find more information about this feature at the [AutoPipe Guide](https://huggingface.co/docs/diffusers/tutorials/autopipeline).
|
||||
|
||||
|
||||
## Checkpoint variants
|
||||
|
||||
A checkpoint variant is usually a checkpoint whose weights are:
|
||||
|
||||
@@ -153,43 +153,18 @@ image
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
|
||||
</div>
|
||||
|
||||
<Tip>
|
||||
|
||||
For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
|
||||
|
||||
</Tip>
|
||||
|
||||
To unload the LoRA weights, use the [`~loaders.LoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
|
||||
|
||||
```py
|
||||
pipeline.unload_lora_weights()
|
||||
```
|
||||
|
||||
### Adjust LoRA weight scale
|
||||
|
||||
For both [`~loaders.LoraLoaderMixin.load_lora_weights`] and [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`], you can pass the `cross_attention_kwargs={"scale": 0.5}` parameter to adjust how much of the LoRA weights to use. A value of `0` is the same as only using the base model weights, and a value of `1` is equivalent to using the fully finetuned LoRA.
|
||||
|
||||
For more granular control on the amount of LoRA weights used per layer, you can use [`~loaders.LoraLoaderMixin.set_adapters`] and pass a dictionary specifying by how much to scale the weights in each layer by.
|
||||
```python
|
||||
pipe = ... # create pipeline
|
||||
pipe.load_lora_weights(..., adapter_name="my_adapter")
|
||||
scales = {
|
||||
"text_encoder": 0.5,
|
||||
"text_encoder_2": 0.5, # only usable if pipe has a 2nd text encoder
|
||||
"unet": {
|
||||
"down": 0.9, # all transformers in the down-part will use scale 0.9
|
||||
# "mid" # in this example "mid" is not given, therefore all transformers in the mid part will use the default scale 1.0
|
||||
"up": {
|
||||
"block_0": 0.6, # all 3 transformers in the 0th block in the up-part will use scale 0.6
|
||||
"block_1": [0.4, 0.8, 1.0], # the 3 transformers in the 1st block in the up-part will use scales 0.4, 0.8 and 1.0 respectively
|
||||
}
|
||||
}
|
||||
}
|
||||
pipe.set_adapters("my_adapter", scales)
|
||||
```
|
||||
|
||||
This also works with multiple adapters - see [this guide](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength) for how to do it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Currently, [`~loaders.LoraLoaderMixin.set_adapters`] only supports scaling attention weights. If a LoRA has other parts (e.g., resnets or down-/upsamplers), they will keep a scale of 1.0.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Kohya and TheLastBen
|
||||
|
||||
Other popular LoRA trainers from the community include those by [Kohya](https://github.com/kohya-ss/sd-scripts/) and [TheLastBen](https://github.com/TheLastBen/fast-stable-diffusion). These trainers create different LoRA checkpoints than those trained by 🤗 Diffusers, but they can still be loaded in the same way.
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
||||
| [**ControlNet**](./controlnet) | ✅ | ✅ | -
|
||||
| [**InstructPix2Pix**](./instruct_pix2pix) | ✅ | ✅ | -
|
||||
| [**Reinforcement Learning for Control**](./reinforcement_learning) | - | - | coming soon.
|
||||
| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/reinforcement_learning/run_diffusers_locomotion.py) | - | - | coming soon.
|
||||
|
||||
## Community
|
||||
|
||||
|
||||
@@ -308,6 +308,6 @@ accelerate launch train_dreambooth_lora_sdxl_advanced.py \
|
||||
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
|
||||
|
||||
## Running on Colab Notebook
|
||||
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_Dreambooth_LoRA_advanced_example.ipynb).
|
||||
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb).
|
||||
to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free colab, using some of the advanced features (excluding pivotal tuning)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -657,6 +656,7 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_dora",
|
||||
type=bool,
|
||||
action="store_true",
|
||||
default=False,
|
||||
help=(
|
||||
@@ -1845,12 +1845,7 @@ def main(args):
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.cuda.amp.autocast():
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import hashlib
|
||||
import itertools
|
||||
@@ -25,7 +26,6 @@ import random
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -2192,12 +2192,13 @@ def main(args):
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext()
|
||||
if "playground" in args.pretrained_model_name_or_path
|
||||
else torch.cuda.amp.autocast()
|
||||
)
|
||||
|
||||
with autocast_ctx:
|
||||
with inference_ctx:
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -430,9 +430,6 @@ def main(args):
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if accelerator.is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
+67
-374
File diff suppressed because it is too large
Load Diff
@@ -103,7 +103,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
print(f"Combining with alpha={alpha}, interpolation mode={interp}")
|
||||
|
||||
checkpoint_count = len(pretrained_model_name_or_path_list)
|
||||
# Ignore result from model_index_json comparison of the two checkpoints
|
||||
# Ignore result from model_index_json comparision of the two checkpoints
|
||||
force = kwargs.pop("force", False)
|
||||
|
||||
# If less than 2 checkpoints, nothing to merge. If more than 3, not supported for now.
|
||||
@@ -217,7 +217,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
]
|
||||
checkpoint_path_2 = files[0] if len(files) > 0 else None
|
||||
# For an attr if both checkpoint_path_1 and 2 are None, ignore.
|
||||
# If at least one is present, deal with it according to interp method, of course only if the state_dict keys match.
|
||||
# If atleast one is present, deal with it according to interp method, of course only if the state_dict keys match.
|
||||
if checkpoint_path_1 is None and checkpoint_path_2 is None:
|
||||
print(f"Skipping {attr}: not present in 2nd or 3d model")
|
||||
continue
|
||||
|
||||
@@ -1,994 +0,0 @@
|
||||
import math
|
||||
import numbers
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from diffusers.models import AsymmetricAutoencoderKL, ImageProjection
|
||||
from diffusers.models.attention_processor import Attention, AttnProcessor
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import (
|
||||
StableDiffusionInpaintPipeline,
|
||||
retrieve_timesteps,
|
||||
)
|
||||
from diffusers.utils import deprecate
|
||||
|
||||
|
||||
class RASGAttnProcessor:
|
||||
def __init__(self, mask, token_idx, scale_factor):
|
||||
self.attention_scores = None # Stores the last output of the similarity matrix here. Each layer will get its own RASGAttnProcessor assigned
|
||||
self.mask = mask
|
||||
self.token_idx = token_idx
|
||||
self.scale_factor = scale_factor
|
||||
self.mask_resoltuion = mask.shape[-1] * mask.shape[-2] # 64 x 64 if the image is 512x512
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
# Same as the default AttnProcessor up untill the part where similarity matrix gets saved
|
||||
downscale_factor = self.mask_resoltuion // hidden_states.shape[1]
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
# Automatically recognize the resolution and save the attention similarity values
|
||||
# We need to use the values before the softmax function, hence the rewritten get_attention_scores function.
|
||||
if downscale_factor == self.scale_factor**2:
|
||||
self.attention_scores = get_attention_scores(attn, query, key, attention_mask)
|
||||
attention_probs = self.attention_scores.softmax(dim=-1)
|
||||
attention_probs = attention_probs.to(query.dtype)
|
||||
else:
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask) # Original code
|
||||
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAIntAAttnProcessor:
|
||||
def __init__(self, transformer_block, mask, token_idx, do_classifier_free_guidance, scale_factors):
|
||||
self.transformer_block = transformer_block # Stores the parent transformer block.
|
||||
self.mask = mask
|
||||
self.scale_factors = scale_factors
|
||||
self.do_classifier_free_guidance = do_classifier_free_guidance
|
||||
self.token_idx = token_idx
|
||||
self.shape = mask.shape[2:]
|
||||
self.mask_resoltuion = mask.shape[-1] * mask.shape[-2] # 64 x 64
|
||||
self.default_processor = AttnProcessor()
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
# Automatically recognize the resolution of the current attention layer and resize the masks accordingly
|
||||
downscale_factor = self.mask_resoltuion // hidden_states.shape[1]
|
||||
|
||||
mask = None
|
||||
for factor in self.scale_factors:
|
||||
if downscale_factor == factor**2:
|
||||
shape = (self.shape[0] // factor, self.shape[1] // factor)
|
||||
mask = F.interpolate(self.mask, shape, mode="bicubic") # B, 1, H, W
|
||||
break
|
||||
if mask is None:
|
||||
return self.default_processor(attn, hidden_states, encoder_hidden_states, attention_mask, temb, scale)
|
||||
|
||||
# STARTS HERE
|
||||
residual = hidden_states
|
||||
# Save the input hidden_states for later use
|
||||
input_hidden_states = hidden_states
|
||||
|
||||
# ================================================== #
|
||||
# =============== SELF ATTENTION 1 ================= #
|
||||
# ================================================== #
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
# self_attention_probs = attn.get_attention_scores(query, key, attention_mask) # We can't use post-softmax attention scores in this case
|
||||
self_attention_scores = get_attention_scores(
|
||||
attn, query, key, attention_mask
|
||||
) # The custom function returns pre-softmax probabilities
|
||||
self_attention_probs = self_attention_scores.softmax(
|
||||
dim=-1
|
||||
) # Manually compute the probabilities here, the scores will be reused in the second part of PAIntA
|
||||
self_attention_probs = self_attention_probs.to(query.dtype)
|
||||
|
||||
hidden_states = torch.bmm(self_attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
# x = x + self.attn1(self.norm1(x))
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection: # So many residuals everywhere
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
self_attention_output_hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
# ================================================== #
|
||||
# ============ BasicTransformerBlock =============== #
|
||||
# ================================================== #
|
||||
# We use a hack by running the code from the BasicTransformerBlock that is between Self and Cross attentions here
|
||||
# The other option would've been modifying the BasicTransformerBlock and adding this functionality here.
|
||||
# I assumed that changing the BasicTransformerBlock would have been a bigger deal and decided to use this hack isntead.
|
||||
|
||||
# The SelfAttention block recieves the normalized latents from the BasicTransformerBlock,
|
||||
# But the residual of the output is the non-normalized version.
|
||||
# Therefore we unnormalize the input hidden state here
|
||||
unnormalized_input_hidden_states = (
|
||||
input_hidden_states + self.transformer_block.norm1.bias
|
||||
) * self.transformer_block.norm1.weight
|
||||
|
||||
# TODO: return if neccessary
|
||||
# if self.use_ada_layer_norm_zero:
|
||||
# attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
# elif self.use_ada_layer_norm_single:
|
||||
# attn_output = gate_msa * attn_output
|
||||
|
||||
transformer_hidden_states = self_attention_output_hidden_states + unnormalized_input_hidden_states
|
||||
if transformer_hidden_states.ndim == 4:
|
||||
transformer_hidden_states = transformer_hidden_states.squeeze(1)
|
||||
|
||||
# TODO: return if neccessary
|
||||
# 2.5 GLIGEN Control
|
||||
# if gligen_kwargs is not None:
|
||||
# transformer_hidden_states = self.fuser(transformer_hidden_states, gligen_kwargs["objs"])
|
||||
# NOTE: we experimented with using GLIGEN and HDPainter together, the results were not that great
|
||||
|
||||
# 3. Cross-Attention
|
||||
if self.transformer_block.use_ada_layer_norm:
|
||||
# transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states, timestep)
|
||||
raise NotImplementedError()
|
||||
elif self.transformer_block.use_ada_layer_norm_zero or self.transformer_block.use_layer_norm:
|
||||
transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states)
|
||||
elif self.transformer_block.use_ada_layer_norm_single:
|
||||
# For PixArt norm2 isn't applied here:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
|
||||
transformer_norm_hidden_states = transformer_hidden_states
|
||||
elif self.transformer_block.use_ada_layer_norm_continuous:
|
||||
# transformer_norm_hidden_states = self.transformer_block.norm2(transformer_hidden_states, added_cond_kwargs["pooled_text_emb"])
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise ValueError("Incorrect norm")
|
||||
|
||||
if self.transformer_block.pos_embed is not None and self.transformer_block.use_ada_layer_norm_single is False:
|
||||
transformer_norm_hidden_states = self.transformer_block.pos_embed(transformer_norm_hidden_states)
|
||||
|
||||
# ================================================== #
|
||||
# ================= CROSS ATTENTION ================ #
|
||||
# ================================================== #
|
||||
|
||||
# We do an initial pass of the CrossAttention up to obtaining the similarity matrix here.
|
||||
# The similarity matrix is used to obtain scaling coefficients for the attention matrix of the self attention
|
||||
# We reuse the previously computed self-attention matrix, and only repeat the steps after the softmax
|
||||
|
||||
cross_attention_input_hidden_states = (
|
||||
transformer_norm_hidden_states # Renaming the variable for the sake of readability
|
||||
)
|
||||
|
||||
# TODO: check if classifier_free_guidance is being used before splitting here
|
||||
if self.do_classifier_free_guidance:
|
||||
# Our scaling coefficients depend only on the conditional part, so we split the inputs
|
||||
(
|
||||
_cross_attention_input_hidden_states_unconditional,
|
||||
cross_attention_input_hidden_states_conditional,
|
||||
) = cross_attention_input_hidden_states.chunk(2)
|
||||
|
||||
# Same split for the encoder_hidden_states i.e. the tokens
|
||||
# Since the SelfAttention processors don't get the encoder states as input, we inject them into the processor in the begining.
|
||||
_encoder_hidden_states_unconditional, encoder_hidden_states_conditional = self.encoder_hidden_states.chunk(
|
||||
2
|
||||
)
|
||||
else:
|
||||
cross_attention_input_hidden_states_conditional = cross_attention_input_hidden_states
|
||||
encoder_hidden_states_conditional = self.encoder_hidden_states.chunk(2)
|
||||
|
||||
# Rename the variables for the sake of readability
|
||||
# The part below is the beginning of the __call__ function of the following CrossAttention layer
|
||||
cross_attention_hidden_states = cross_attention_input_hidden_states_conditional
|
||||
cross_attention_encoder_hidden_states = encoder_hidden_states_conditional
|
||||
|
||||
attn2 = self.transformer_block.attn2
|
||||
|
||||
if attn2.spatial_norm is not None:
|
||||
cross_attention_hidden_states = attn2.spatial_norm(cross_attention_hidden_states, temb)
|
||||
|
||||
input_ndim = cross_attention_hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = cross_attention_hidden_states.shape
|
||||
cross_attention_hidden_states = cross_attention_hidden_states.view(
|
||||
batch_size, channel, height * width
|
||||
).transpose(1, 2)
|
||||
|
||||
(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
_,
|
||||
) = cross_attention_hidden_states.shape # It is definitely a cross attention, so no need for an if block
|
||||
# TODO: change the attention_mask here
|
||||
attention_mask = attn2.prepare_attention_mask(
|
||||
None, sequence_length, batch_size
|
||||
) # I assume the attention mask is the same...
|
||||
|
||||
if attn2.group_norm is not None:
|
||||
cross_attention_hidden_states = attn2.group_norm(cross_attention_hidden_states.transpose(1, 2)).transpose(
|
||||
1, 2
|
||||
)
|
||||
|
||||
query2 = attn2.to_q(cross_attention_hidden_states)
|
||||
|
||||
if attn2.norm_cross:
|
||||
cross_attention_encoder_hidden_states = attn2.norm_encoder_hidden_states(
|
||||
cross_attention_encoder_hidden_states
|
||||
)
|
||||
|
||||
key2 = attn2.to_k(cross_attention_encoder_hidden_states)
|
||||
query2 = attn2.head_to_batch_dim(query2)
|
||||
key2 = attn2.head_to_batch_dim(key2)
|
||||
|
||||
cross_attention_probs = attn2.get_attention_scores(query2, key2, attention_mask)
|
||||
|
||||
# CrossAttention ends here, the remaining part is not used
|
||||
|
||||
# ================================================== #
|
||||
# ================ SELF ATTENTION 2 ================ #
|
||||
# ================================================== #
|
||||
# DEJA VU!
|
||||
|
||||
mask = (mask > 0.5).to(self_attention_output_hidden_states.dtype)
|
||||
m = mask.to(self_attention_output_hidden_states.device)
|
||||
# m = rearrange(m, 'b c h w -> b (h w) c').contiguous()
|
||||
m = m.permute(0, 2, 3, 1).reshape((m.shape[0], -1, m.shape[1])).contiguous() # B HW 1
|
||||
m = torch.matmul(m, m.permute(0, 2, 1)) + (1 - m)
|
||||
|
||||
# # Compute scaling coefficients for the similarity matrix
|
||||
# # Select the cross attention values for the correct tokens only!
|
||||
# cross_attention_probs = cross_attention_probs.mean(dim = 0)
|
||||
# cross_attention_probs = cross_attention_probs[:, self.token_idx].sum(dim=1)
|
||||
|
||||
# cross_attention_probs = cross_attention_probs.reshape(shape)
|
||||
# gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(self_attention_output_hidden_states.device)
|
||||
# cross_attention_probs = gaussian_smoothing(cross_attention_probs.unsqueeze(0))[0] # optional smoothing
|
||||
# cross_attention_probs = cross_attention_probs.reshape(-1)
|
||||
# cross_attention_probs = ((cross_attention_probs - torch.median(cross_attention_probs.ravel())) / torch.max(cross_attention_probs.ravel())).clip(0, 1)
|
||||
|
||||
# c = (1 - m) * cross_attention_probs.reshape(1, 1, -1) + m # PAIntA scaling coefficients
|
||||
|
||||
# Compute scaling coefficients for the similarity matrix
|
||||
# Select the cross attention values for the correct tokens only!
|
||||
|
||||
batch_size, dims, channels = cross_attention_probs.shape
|
||||
batch_size = batch_size // attn.heads
|
||||
cross_attention_probs = cross_attention_probs.reshape((batch_size, attn.heads, dims, channels)) # B, D, HW, T
|
||||
|
||||
cross_attention_probs = cross_attention_probs.mean(dim=1) # B, HW, T
|
||||
cross_attention_probs = cross_attention_probs[..., self.token_idx].sum(dim=-1) # B, HW
|
||||
cross_attention_probs = cross_attention_probs.reshape((batch_size,) + shape) # , B, H, W
|
||||
|
||||
gaussian_smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=0.5, dim=2).to(
|
||||
self_attention_output_hidden_states.device
|
||||
)
|
||||
cross_attention_probs = gaussian_smoothing(cross_attention_probs[:, None])[:, 0] # optional smoothing B, H, W
|
||||
|
||||
# Median normalization
|
||||
cross_attention_probs = cross_attention_probs.reshape(batch_size, -1) # B, HW
|
||||
cross_attention_probs = (
|
||||
cross_attention_probs - cross_attention_probs.median(dim=-1, keepdim=True).values
|
||||
) / cross_attention_probs.max(dim=-1, keepdim=True).values
|
||||
cross_attention_probs = cross_attention_probs.clip(0, 1)
|
||||
|
||||
c = (1 - m) * cross_attention_probs.reshape(batch_size, 1, -1) + m
|
||||
c = c.repeat_interleave(attn.heads, 0) # BD, HW
|
||||
if self.do_classifier_free_guidance:
|
||||
c = torch.cat([c, c]) # 2BD, HW
|
||||
|
||||
# Rescaling the original self-attention matrix
|
||||
self_attention_scores_rescaled = self_attention_scores * c
|
||||
self_attention_probs_rescaled = self_attention_scores_rescaled.softmax(dim=-1)
|
||||
|
||||
# Continuing the self attention normally using the new matrix
|
||||
hidden_states = torch.bmm(self_attention_probs_rescaled, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + input_hidden_states
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
|
||||
def get_tokenized_prompt(self, prompt):
|
||||
out = self.tokenizer(prompt)
|
||||
return [self.tokenizer.decode(x) for x in out["input_ids"]]
|
||||
|
||||
def init_attn_processors(
|
||||
self,
|
||||
mask,
|
||||
token_idx,
|
||||
use_painta=True,
|
||||
use_rasg=True,
|
||||
painta_scale_factors=[2, 4], # 64x64 -> [16x16, 32x32]
|
||||
rasg_scale_factor=4, # 64x64 -> 16x16
|
||||
self_attention_layer_name="attn1",
|
||||
cross_attention_layer_name="attn2",
|
||||
list_of_painta_layer_names=None,
|
||||
list_of_rasg_layer_names=None,
|
||||
):
|
||||
default_processor = AttnProcessor()
|
||||
width, height = mask.shape[-2:]
|
||||
width, height = width // self.vae_scale_factor, height // self.vae_scale_factor
|
||||
|
||||
painta_scale_factors = [x * self.vae_scale_factor for x in painta_scale_factors]
|
||||
rasg_scale_factor = self.vae_scale_factor * rasg_scale_factor
|
||||
|
||||
attn_processors = {}
|
||||
for x in self.unet.attn_processors:
|
||||
if (list_of_painta_layer_names is None and self_attention_layer_name in x) or (
|
||||
list_of_painta_layer_names is not None and x in list_of_painta_layer_names
|
||||
):
|
||||
if use_painta:
|
||||
transformer_block = self.unet.get_submodule(x.replace(".attn1.processor", ""))
|
||||
attn_processors[x] = PAIntAAttnProcessor(
|
||||
transformer_block, mask, token_idx, self.do_classifier_free_guidance, painta_scale_factors
|
||||
)
|
||||
else:
|
||||
attn_processors[x] = default_processor
|
||||
elif (list_of_rasg_layer_names is None and cross_attention_layer_name in x) or (
|
||||
list_of_rasg_layer_names is not None and x in list_of_rasg_layer_names
|
||||
):
|
||||
if use_rasg:
|
||||
attn_processors[x] = RASGAttnProcessor(mask, token_idx, rasg_scale_factor)
|
||||
else:
|
||||
attn_processors[x] = default_processor
|
||||
|
||||
self.unet.set_attn_processor(attn_processors)
|
||||
# import json
|
||||
# with open('/home/hayk.manukyan/repos/diffusers/debug.txt', 'a') as f:
|
||||
# json.dump({x:str(y) for x,y in self.unet.attn_processors.items()}, f, indent=4)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
masked_image_latents: torch.FloatTensor = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
padding_mask_crop: Optional[int] = None,
|
||||
strength: float = 1.0,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.5,
|
||||
positive_prompt: Optional[str] = "",
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.01,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: int = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
use_painta=True,
|
||||
use_rasg=True,
|
||||
self_attention_layer_name=".attn1",
|
||||
cross_attention_layer_name=".attn2",
|
||||
painta_scale_factors=[2, 4], # 16 x 16 and 32 x 32
|
||||
rasg_scale_factor=4, # 16x16 by default
|
||||
list_of_painta_layer_names=None,
|
||||
list_of_rasg_layer_names=None,
|
||||
**kwargs,
|
||||
):
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
#
|
||||
prompt_no_positives = prompt
|
||||
if isinstance(prompt, list):
|
||||
prompt = [x + positive_prompt for x in prompt]
|
||||
else:
|
||||
prompt = prompt + positive_prompt
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
mask_image,
|
||||
height,
|
||||
width,
|
||||
strength,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
padding_mask_crop,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# assert batch_size == 1, "Does not work with batch size > 1 currently"
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
if ip_adapter_image is not None:
|
||||
output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
|
||||
image_embeds, negative_image_embeds = self.encode_image(
|
||||
ip_adapter_image, device, num_images_per_prompt, output_hidden_state
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
image_embeds = torch.cat([negative_image_embeds, image_embeds])
|
||||
|
||||
# 4. set timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
timesteps, num_inference_steps = self.get_timesteps(
|
||||
num_inference_steps=num_inference_steps, strength=strength, device=device
|
||||
)
|
||||
# check that number of inference steps is not < 1 - as this doesn't make sense
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
|
||||
is_strength_max = strength == 1.0
|
||||
|
||||
# 5. Preprocess mask and image
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
|
||||
resize_mode = "fill"
|
||||
else:
|
||||
crops_coords = None
|
||||
resize_mode = "default"
|
||||
|
||||
original_image = image
|
||||
init_image = self.image_processor.preprocess(
|
||||
image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
|
||||
)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
num_channels_unet = self.unet.config.in_channels
|
||||
return_image_latents = num_channels_unet == 4
|
||||
|
||||
latents_outputs = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
image=init_image,
|
||||
timestep=latent_timestep,
|
||||
is_strength_max=is_strength_max,
|
||||
return_noise=True,
|
||||
return_image_latents=return_image_latents,
|
||||
)
|
||||
|
||||
if return_image_latents:
|
||||
latents, noise, image_latents = latents_outputs
|
||||
else:
|
||||
latents, noise = latents_outputs
|
||||
|
||||
# 7. Prepare mask latent variables
|
||||
mask_condition = self.mask_processor.preprocess(
|
||||
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
|
||||
)
|
||||
|
||||
if masked_image_latents is None:
|
||||
masked_image = init_image * (mask_condition < 0.5)
|
||||
else:
|
||||
masked_image = masked_image_latents
|
||||
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask_condition,
|
||||
masked_image,
|
||||
batch_size * num_images_per_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 7.5 Setting up HD-Painter
|
||||
|
||||
# Get the indices of the tokens to be modified by both RASG and PAIntA
|
||||
token_idx = list(range(1, self.get_tokenized_prompt(prompt_no_positives).index("<|endoftext|>"))) + [
|
||||
self.get_tokenized_prompt(prompt).index("<|endoftext|>")
|
||||
]
|
||||
|
||||
# Setting up the attention processors
|
||||
self.init_attn_processors(
|
||||
mask_condition,
|
||||
token_idx,
|
||||
use_painta,
|
||||
use_rasg,
|
||||
painta_scale_factors=painta_scale_factors,
|
||||
rasg_scale_factor=rasg_scale_factor,
|
||||
self_attention_layer_name=self_attention_layer_name,
|
||||
cross_attention_layer_name=cross_attention_layer_name,
|
||||
list_of_painta_layer_names=list_of_painta_layer_names,
|
||||
list_of_rasg_layer_names=list_of_rasg_layer_names,
|
||||
)
|
||||
|
||||
# 8. Check that sizes of mask, masked image and latents match
|
||||
if num_channels_unet == 9:
|
||||
# default case for runwayml/stable-diffusion-inpainting
|
||||
num_channels_mask = mask.shape[1]
|
||||
num_channels_masked_image = masked_image_latents.shape[1]
|
||||
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
|
||||
raise ValueError(
|
||||
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
|
||||
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
|
||||
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
|
||||
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
|
||||
" `pipeline.unet` or your `mask_image` or `image` input."
|
||||
)
|
||||
elif num_channels_unet != 4:
|
||||
raise ValueError(
|
||||
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
|
||||
)
|
||||
|
||||
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
if use_rasg:
|
||||
extra_step_kwargs["generator"] = None
|
||||
|
||||
# 9.1 Add image embeds for IP-Adapter
|
||||
added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
|
||||
|
||||
# 9.2 Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
timestep_cond = self.get_guidance_scale_embedding(
|
||||
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
# 10. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
painta_active = True
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
if t < 500 and painta_active:
|
||||
self.init_attn_processors(
|
||||
mask_condition,
|
||||
token_idx,
|
||||
False,
|
||||
use_rasg,
|
||||
painta_scale_factors=painta_scale_factors,
|
||||
rasg_scale_factor=rasg_scale_factor,
|
||||
self_attention_layer_name=self_attention_layer_name,
|
||||
cross_attention_layer_name=cross_attention_layer_name,
|
||||
list_of_painta_layer_names=list_of_painta_layer_names,
|
||||
list_of_rasg_layer_names=list_of_rasg_layer_names,
|
||||
)
|
||||
painta_active = False
|
||||
|
||||
with torch.enable_grad():
|
||||
self.unet.zero_grad()
|
||||
latents = latents.detach()
|
||||
latents.requires_grad = True
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
# concat latents, mask, masked_image_latents in the channel dimension
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
if num_channels_unet == 9:
|
||||
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
|
||||
|
||||
self.scheduler.latents = latents
|
||||
self.encoder_hidden_states = prompt_embeds
|
||||
for attn_processor in self.unet.attn_processors.values():
|
||||
attn_processor.encoder_hidden_states = prompt_embeds
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if use_rasg:
|
||||
# Perform RASG
|
||||
_, _, height, width = mask_condition.shape # 512 x 512
|
||||
scale_factor = self.vae_scale_factor * rasg_scale_factor # 8 * 4 = 32
|
||||
|
||||
# TODO: Fix for > 1 batch_size
|
||||
rasg_mask = F.interpolate(
|
||||
mask_condition, (height // scale_factor, width // scale_factor), mode="bicubic"
|
||||
)[0, 0] # mode is nearest by default, B, H, W
|
||||
|
||||
# Aggregate the saved attention maps
|
||||
attn_map = []
|
||||
for processor in self.unet.attn_processors.values():
|
||||
if hasattr(processor, "attention_scores") and processor.attention_scores is not None:
|
||||
if self.do_classifier_free_guidance:
|
||||
attn_map.append(processor.attention_scores.chunk(2)[1]) # (B/2) x H, 256, 77
|
||||
else:
|
||||
attn_map.append(processor.attention_scores) # B x H, 256, 77 ?
|
||||
|
||||
attn_map = (
|
||||
torch.cat(attn_map)
|
||||
.mean(0)
|
||||
.permute(1, 0)
|
||||
.reshape((-1, height // scale_factor, width // scale_factor))
|
||||
) # 77, 16, 16
|
||||
|
||||
# Compute the attention score
|
||||
attn_score = -sum(
|
||||
[
|
||||
F.binary_cross_entropy_with_logits(x - 1.0, rasg_mask.to(device))
|
||||
for x in attn_map[token_idx]
|
||||
]
|
||||
)
|
||||
|
||||
# Backward the score and compute the gradients
|
||||
attn_score.backward()
|
||||
|
||||
# Normalzie the gradients and compute the noise component
|
||||
variance_noise = latents.grad.detach()
|
||||
# print("VARIANCE SHAPE", variance_noise.shape)
|
||||
variance_noise -= torch.mean(variance_noise, [1, 2, 3], keepdim=True)
|
||||
variance_noise /= torch.std(variance_noise, [1, 2, 3], keepdim=True)
|
||||
else:
|
||||
variance_noise = None
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs, return_dict=False, variance_noise=variance_noise
|
||||
)[0]
|
||||
|
||||
if num_channels_unet == 4:
|
||||
init_latents_proper = image_latents
|
||||
if self.do_classifier_free_guidance:
|
||||
init_mask, _ = mask.chunk(2)
|
||||
else:
|
||||
init_mask = mask
|
||||
|
||||
if i < len(timesteps) - 1:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
init_latents_proper, noise, torch.tensor([noise_timestep])
|
||||
)
|
||||
|
||||
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
mask = callback_outputs.pop("mask", mask)
|
||||
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
condition_kwargs = {}
|
||||
if isinstance(self.vae, AsymmetricAutoencoderKL):
|
||||
init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
|
||||
init_image_condition = init_image.clone()
|
||||
init_image = self._encode_vae_image(init_image, generator=generator)
|
||||
mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
|
||||
condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
|
||||
image = self.vae.decode(
|
||||
latents / self.vae.config.scaling_factor, return_dict=False, generator=generator, **condition_kwargs
|
||||
)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
if padding_mask_crop is not None:
|
||||
image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
# ============= Utility Functions ============== #
|
||||
|
||||
|
||||
class GaussianSmoothing(nn.Module):
|
||||
"""
|
||||
Apply gaussian smoothing on a
|
||||
1d, 2d or 3d tensor. Filtering is performed seperately for each channel
|
||||
in the input using a depthwise convolution.
|
||||
Arguments:
|
||||
channels (int, sequence): Number of channels of the input tensors. Output will
|
||||
have this number of channels as well.
|
||||
kernel_size (int, sequence): Size of the gaussian kernel.
|
||||
sigma (float, sequence): Standard deviation of the gaussian kernel.
|
||||
dim (int, optional): The number of dimensions of the data.
|
||||
Default value is 2 (spatial).
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, sigma, dim=2):
|
||||
super(GaussianSmoothing, self).__init__()
|
||||
if isinstance(kernel_size, numbers.Number):
|
||||
kernel_size = [kernel_size] * dim
|
||||
if isinstance(sigma, numbers.Number):
|
||||
sigma = [sigma] * dim
|
||||
|
||||
# The gaussian kernel is the product of the
|
||||
# gaussian function of each dimension.
|
||||
kernel = 1
|
||||
meshgrids = torch.meshgrid([torch.arange(size, dtype=torch.float32) for size in kernel_size])
|
||||
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||
mean = (size - 1) / 2
|
||||
kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-(((mgrid - mean) / (2 * std)) ** 2))
|
||||
|
||||
# Make sure sum of values in gaussian kernel equals 1.
|
||||
kernel = kernel / torch.sum(kernel)
|
||||
|
||||
# Reshape to depthwise convolutional weight
|
||||
kernel = kernel.view(1, 1, *kernel.size())
|
||||
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||
|
||||
self.register_buffer("weight", kernel)
|
||||
self.groups = channels
|
||||
|
||||
if dim == 1:
|
||||
self.conv = F.conv1d
|
||||
elif dim == 2:
|
||||
self.conv = F.conv2d
|
||||
elif dim == 3:
|
||||
self.conv = F.conv3d
|
||||
else:
|
||||
raise RuntimeError("Only 1, 2 and 3 dimensions are supported. Received {}.".format(dim))
|
||||
|
||||
def forward(self, input):
|
||||
"""
|
||||
Apply gaussian filter to input.
|
||||
Arguments:
|
||||
input (torch.Tensor): Input to apply gaussian filter on.
|
||||
Returns:
|
||||
filtered (torch.Tensor): Filtered output.
|
||||
"""
|
||||
return self.conv(input, weight=self.weight.to(input.dtype), groups=self.groups, padding="same")
|
||||
|
||||
|
||||
def get_attention_scores(
|
||||
self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Compute the attention scores.
|
||||
|
||||
Args:
|
||||
query (`torch.Tensor`): The query tensor.
|
||||
key (`torch.Tensor`): The key tensor.
|
||||
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The attention probabilities/scores.
|
||||
"""
|
||||
if self.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
|
||||
if attention_mask is None:
|
||||
baddbmm_input = torch.empty(
|
||||
query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
|
||||
)
|
||||
beta = 0
|
||||
else:
|
||||
baddbmm_input = attention_mask
|
||||
beta = 1
|
||||
|
||||
attention_scores = torch.baddbmm(
|
||||
baddbmm_input,
|
||||
query,
|
||||
key.transpose(-1, -2),
|
||||
beta=beta,
|
||||
alpha=self.scale,
|
||||
)
|
||||
del baddbmm_input
|
||||
|
||||
if self.upcast_softmax:
|
||||
attention_scores = attention_scores.float()
|
||||
|
||||
return attention_scores
|
||||
@@ -726,7 +726,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
embedding_interpolation_type (`str`, *optional*, defaults to `"lerp"`):
|
||||
The type of interpolation to use for interpolating between text embeddings. Choose between `"lerp"` and `"slerp"`.
|
||||
latent_interpolation_type (`str`, *optional*, defaults to `"slerp"`):
|
||||
@@ -779,7 +779,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if batch_size < 2:
|
||||
raise ValueError(f"`prompt` must have length of at least 2 but found {batch_size}")
|
||||
raise ValueError(f"`prompt` must have length of atleast 2 but found {batch_size}")
|
||||
if num_images_per_prompt != 1:
|
||||
raise ValueError("`num_images_per_prompt` must be `1` as no other value is supported yet")
|
||||
if prompt_embeds is not None:
|
||||
@@ -883,7 +883,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
) as batch_progress_bar:
|
||||
for batch_index in range(0, bs, process_batch_size):
|
||||
batch_inference_latents = inference_latents[batch_index : batch_index + process_batch_size]
|
||||
batch_inference_embeddings = inference_embeddings[
|
||||
batch_inference_embedddings = inference_embeddings[
|
||||
batch_index : batch_index + process_batch_size
|
||||
]
|
||||
|
||||
@@ -892,7 +892,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
current_bs = batch_inference_embeddings.shape[0]
|
||||
current_bs = batch_inference_embedddings.shape[0]
|
||||
w = torch.tensor(self.guidance_scale - 1).repeat(current_bs)
|
||||
w_embedding = self.get_guidance_scale_embedding(
|
||||
w, embedding_dim=self.unet.config.time_cond_proj_dim
|
||||
@@ -901,14 +901,14 @@ class LatentConsistencyModelWalkPipeline(
|
||||
# 10. Perform inference for current batch
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for index, t in enumerate(timesteps):
|
||||
batch_inference_latents = batch_inference_latents.to(batch_inference_embeddings.dtype)
|
||||
batch_inference_latents = batch_inference_latents.to(batch_inference_embedddings.dtype)
|
||||
|
||||
# model prediction (v-prediction, eps, x)
|
||||
model_pred = self.unet(
|
||||
batch_inference_latents,
|
||||
t,
|
||||
timestep_cond=w_embedding,
|
||||
encoder_hidden_states=batch_inference_embeddings,
|
||||
encoder_hidden_states=batch_inference_embedddings,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
@@ -924,8 +924,8 @@ class LatentConsistencyModelWalkPipeline(
|
||||
callback_outputs = callback_on_step_end(self, index, t, callback_kwargs)
|
||||
|
||||
batch_inference_latents = callback_outputs.pop("latents", batch_inference_latents)
|
||||
batch_inference_embeddings = callback_outputs.pop(
|
||||
"prompt_embeds", batch_inference_embeddings
|
||||
batch_inference_embedddings = callback_outputs.pop(
|
||||
"prompt_embeds", batch_inference_embedddings
|
||||
)
|
||||
w_embedding = callback_outputs.pop("w_embedding", w_embedding)
|
||||
denoised = callback_outputs.pop("denoised", denoised)
|
||||
@@ -939,7 +939,7 @@ class LatentConsistencyModelWalkPipeline(
|
||||
step_idx = index // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, batch_inference_latents)
|
||||
|
||||
denoised = denoised.to(batch_inference_embeddings.dtype)
|
||||
denoised = denoised.to(batch_inference_embedddings.dtype)
|
||||
|
||||
# Note: This is not supported because you would get black images in your latent walk if
|
||||
# NSFW concept is detected
|
||||
|
||||
@@ -530,7 +530,7 @@ class LLMGroundedDiffusionPipeline(
|
||||
)
|
||||
|
||||
if len(phrases) != len(boxes):
|
||||
raise ValueError(
|
||||
ValueError(
|
||||
"length of `phrases` and `boxes` has to be same, but"
|
||||
f" got: `phrases` {len(phrases)} != `boxes` {len(boxes)}"
|
||||
)
|
||||
|
||||
@@ -439,9 +439,7 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder-->unet->vae"
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -164,7 +164,7 @@ def get_prompts_tokens_with_weights(clip_tokenizer: CLIPTokenizer, prompt: str):
|
||||
text_tokens (list)
|
||||
A list contains token ids
|
||||
text_weight (list)
|
||||
A list contains the correspondent weight of token ids
|
||||
A list contains the correspodent weight of token ids
|
||||
|
||||
Example:
|
||||
import torch
|
||||
@@ -1028,7 +1028,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
|
||||
# mean that we cut the timesteps in the middle of the denoising step
|
||||
# (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
|
||||
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
|
||||
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
@@ -1531,7 +1531,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -2131,7 +2131,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Override to properly handle the loading and unloading of the additional text encoder.
|
||||
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Dict, Union
|
||||
|
||||
@@ -26,7 +25,6 @@ import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from PIL.Image import Resampling
|
||||
from scipy.optimize import minimize
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from tqdm.auto import tqdm
|
||||
@@ -36,14 +34,13 @@ from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DiffusionPipeline,
|
||||
LCMScheduler,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.25.0")
|
||||
check_min_version("0.28.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
@@ -64,19 +61,6 @@ class MarigoldDepthOutput(BaseOutput):
|
||||
uncertainty: Union[None, np.ndarray]
|
||||
|
||||
|
||||
def get_pil_resample_method(method_str: str) -> Resampling:
|
||||
resample_method_dic = {
|
||||
"bilinear": Resampling.BILINEAR,
|
||||
"bicubic": Resampling.BICUBIC,
|
||||
"nearest": Resampling.NEAREST,
|
||||
}
|
||||
resample_method = resample_method_dic.get(method_str, None)
|
||||
if resample_method is None:
|
||||
raise ValueError(f"Unknown resampling method: {resample_method}")
|
||||
else:
|
||||
return resample_method
|
||||
|
||||
|
||||
class MarigoldPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
|
||||
@@ -129,9 +113,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
ensemble_size: int = 10,
|
||||
processing_res: int = 768,
|
||||
match_input_res: bool = True,
|
||||
resample_method: str = "bilinear",
|
||||
batch_size: int = 0,
|
||||
seed: Union[int, None] = None,
|
||||
color_map: str = "Spectral",
|
||||
show_progress_bar: bool = True,
|
||||
ensemble_kwargs: Dict = None,
|
||||
@@ -147,9 +129,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
If set to 0: will not resize at all.
|
||||
match_input_res (`bool`, *optional*, defaults to `True`):
|
||||
Resize depth prediction to match input resolution.
|
||||
Only valid if `processing_res` > 0.
|
||||
resample_method: (`str`, *optional*, defaults to `bilinear`):
|
||||
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
|
||||
Only valid if `limit_input_res` is not None.
|
||||
denoising_steps (`int`, *optional*, defaults to `10`):
|
||||
Number of diffusion denoising steps (DDIM) during inference.
|
||||
ensemble_size (`int`, *optional*, defaults to `10`):
|
||||
@@ -157,8 +137,6 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
batch_size (`int`, *optional*, defaults to `0`):
|
||||
Inference batch size, no bigger than `num_ensemble`.
|
||||
If set to 0, the script will automatically decide the proper batch size.
|
||||
seed (`int`, *optional*, defaults to `None`)
|
||||
Reproducibility seed.
|
||||
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
|
||||
@@ -168,7 +146,8 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
Returns:
|
||||
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
||||
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
||||
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
|
||||
- **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
|
||||
values in [0, 1]. None if `color_map` is `None`
|
||||
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
|
||||
coming from ensembling. None if `ensemble_size = 1`
|
||||
"""
|
||||
@@ -179,21 +158,13 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
if not match_input_res:
|
||||
assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
|
||||
assert processing_res >= 0
|
||||
assert denoising_steps >= 1
|
||||
assert ensemble_size >= 1
|
||||
|
||||
# Check if denoising step is reasonable
|
||||
self._check_inference_step(denoising_steps)
|
||||
|
||||
resample_method: Resampling = get_pil_resample_method(resample_method)
|
||||
|
||||
# ----------------- Image Preprocess -----------------
|
||||
# Resize image
|
||||
if processing_res > 0:
|
||||
input_image = self.resize_max_res(
|
||||
input_image,
|
||||
max_edge_resolution=processing_res,
|
||||
resample_method=resample_method,
|
||||
)
|
||||
input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res)
|
||||
# Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
|
||||
input_image = input_image.convert("RGB")
|
||||
image = np.asarray(input_image)
|
||||
@@ -232,10 +203,9 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
rgb_in=batched_img,
|
||||
num_inference_steps=denoising_steps,
|
||||
show_pbar=show_progress_bar,
|
||||
seed=seed,
|
||||
)
|
||||
depth_pred_ls.append(depth_pred_raw.detach())
|
||||
depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
|
||||
depth_pred_ls.append(depth_pred_raw.detach().clone())
|
||||
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
|
||||
torch.cuda.empty_cache() # clear vram cache for ensembling
|
||||
|
||||
# ----------------- Test-time ensembling -----------------
|
||||
@@ -257,7 +227,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
# Resize back to original resolution
|
||||
if match_input_res:
|
||||
pred_img = Image.fromarray(depth_pred)
|
||||
pred_img = pred_img.resize(input_size, resample=resample_method)
|
||||
pred_img = pred_img.resize(input_size)
|
||||
depth_pred = np.asarray(pred_img)
|
||||
|
||||
# Clip output range
|
||||
@@ -273,32 +243,12 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
else:
|
||||
depth_colored_img = None
|
||||
|
||||
return MarigoldDepthOutput(
|
||||
depth_np=depth_pred,
|
||||
depth_colored=depth_colored_img,
|
||||
uncertainty=pred_uncert,
|
||||
)
|
||||
|
||||
def _check_inference_step(self, n_step: int):
|
||||
"""
|
||||
Check if denoising step is reasonable
|
||||
Args:
|
||||
n_step (`int`): denoising steps
|
||||
"""
|
||||
assert n_step >= 1
|
||||
|
||||
if isinstance(self.scheduler, DDIMScheduler):
|
||||
if n_step < 10:
|
||||
logging.warning(
|
||||
f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
|
||||
)
|
||||
elif isinstance(self.scheduler, LCMScheduler):
|
||||
if not 1 <= n_step <= 4:
|
||||
logging.warning(f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps.")
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}")
|
||||
|
||||
def _encode_empty_text(self):
|
||||
"""
|
||||
Encode text embedding for empty prompt.
|
||||
@@ -315,13 +265,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def single_infer(
|
||||
self,
|
||||
rgb_in: torch.Tensor,
|
||||
num_inference_steps: int,
|
||||
seed: Union[int, None],
|
||||
show_pbar: bool,
|
||||
) -> torch.Tensor:
|
||||
def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor:
|
||||
"""
|
||||
Perform an individual depth prediction without ensembling.
|
||||
|
||||
@@ -342,20 +286,10 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
timesteps = self.scheduler.timesteps # [T]
|
||||
|
||||
# Encode image
|
||||
rgb_latent = self.encode_rgb(rgb_in)
|
||||
rgb_latent = self._encode_rgb(rgb_in)
|
||||
|
||||
# Initial depth map (noise)
|
||||
if seed is None:
|
||||
rand_num_generator = None
|
||||
else:
|
||||
rand_num_generator = torch.Generator(device=device)
|
||||
rand_num_generator.manual_seed(seed)
|
||||
depth_latent = torch.randn(
|
||||
rgb_latent.shape,
|
||||
device=device,
|
||||
dtype=self.dtype,
|
||||
generator=rand_num_generator,
|
||||
) # [B, 4, h, w]
|
||||
depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w]
|
||||
|
||||
# Batched empty text embedding
|
||||
if self.empty_text_embed is None:
|
||||
@@ -380,9 +314,9 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rand_num_generator).prev_sample
|
||||
|
||||
depth = self.decode_depth(depth_latent)
|
||||
depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
|
||||
torch.cuda.empty_cache()
|
||||
depth = self._decode_depth(depth_latent)
|
||||
|
||||
# clip prediction
|
||||
depth = torch.clip(depth, -1.0, 1.0)
|
||||
@@ -391,7 +325,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
|
||||
return depth
|
||||
|
||||
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
||||
def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode RGB image into latent.
|
||||
|
||||
@@ -410,7 +344,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
rgb_latent = mean * self.rgb_latent_scale_factor
|
||||
return rgb_latent
|
||||
|
||||
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
|
||||
def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Decode depth latent into depth map.
|
||||
|
||||
@@ -431,7 +365,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
return depth_mean
|
||||
|
||||
@staticmethod
|
||||
def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
|
||||
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
||||
"""
|
||||
Resize image to limit maximum edge length while keeping aspect ratio.
|
||||
|
||||
@@ -440,8 +374,6 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
Image to be resized.
|
||||
max_edge_resolution (`int`):
|
||||
Maximum edge length (pixel).
|
||||
resample_method (`PIL.Image.Resampling`):
|
||||
Resampling method used to resize images.
|
||||
|
||||
Returns:
|
||||
`Image.Image`: Resized image.
|
||||
@@ -452,7 +384,7 @@ class MarigoldPipeline(DiffusionPipeline):
|
||||
new_width = int(original_width * downscale_factor)
|
||||
new_height = int(original_height * downscale_factor)
|
||||
|
||||
resized_img = img.resize((new_width, new_height), resample=resample_method)
|
||||
resized_img = img.resize((new_width, new_height))
|
||||
return resized_img
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -196,7 +196,7 @@ class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixi
|
||||
guidance_scale_tiles: specific weights for classifier-free guidance in each tile.
|
||||
guidance_scale_tiles: specific weights for classifier-free guidance in each tile. If None, the value provided in guidance_scale will be used.
|
||||
seed_tiles: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard seed parameter.
|
||||
seed_tiles_mode: either "full" "exclusive". If "full", all the latents affected by the tile be overriden. If "exclusive", only the latents that are affected exclusively by this tile (and no other tiles) will be overriden.
|
||||
seed_tiles_mode: either "full" "exclusive". If "full", all the latents affected by the tile be overriden. If "exclusive", only the latents that are affected exclusively by this tile (and no other tiles) will be overrriden.
|
||||
seed_reroll_regions: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overriden using the given seed. Takes priority over seed_tiles.
|
||||
cpu_vae: the decoder from latent space to pixel space can require too mucho GPU RAM for large images. If you find out of memory errors at the end of the generation process, try setting this parameter to True to run the decoder in CPU. Slower, but should run without memory issues.
|
||||
|
||||
@@ -325,7 +325,7 @@ class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixi
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# Mask for tile weights strength
|
||||
# Mask for tile weights strenght
|
||||
tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size)
|
||||
|
||||
# Diffusion timesteps
|
||||
|
||||
@@ -832,7 +832,7 @@ class AnimateDiffControlNetPipeline(
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
allback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
@@ -840,7 +840,7 @@ class AnimateDiffControlNetPipeline(
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@@ -1280,7 +1280,7 @@ class DemoFusionSDXLPipeline(
|
||||
|
||||
return output_images
|
||||
|
||||
# Override to properly handle the loading and unloading of the additional text encoder.
|
||||
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
# We could have accessed the unet config from `lora_state_dict()` too. We pass
|
||||
# it here explicitly to be able to tell that it's coming from an SDXL
|
||||
|
||||
@@ -887,7 +887,7 @@ class StyleAlignedSDXLPipeline(
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
|
||||
# mean that we cut the timesteps in the middle of the denoising step
|
||||
# (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
|
||||
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
|
||||
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,7 +26,7 @@ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInver
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion_ldm3d.pipeline_stable_diffusion_ldm3d import LDM3DPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_ldm3d import LDM3DPipelineOutput
|
||||
from diffusers.schedulers import DDPMScheduler, KarrasDiffusionSchedulers
|
||||
from diffusers.utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
|
||||
@@ -1073,7 +1073,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
|
||||
# because `num_inference_steps` might be even given that every timestep
|
||||
# (except the highest one) is duplicated. If `num_inference_steps` is even it would
|
||||
# mean that we cut the timesteps in the middle of the denoising step
|
||||
# (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
|
||||
# (between 1st and 2nd devirative) which leads to incorrect results. By adding 1
|
||||
# we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
|
||||
num_inference_steps = num_inference_steps + 1
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -46,11 +46,6 @@ except Exception:
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
logger.warning(
|
||||
"To use instant id pipelines, please make sure you have the `insightface` library installed: `pip install insightface`."
|
||||
"Please refer to: https://huggingface.co/InstantX/InstantID for further instructions regarding inference"
|
||||
)
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4):
|
||||
inner_dim = int(dim * mult)
|
||||
@@ -706,7 +701,7 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
@@ -239,10 +238,6 @@ class SDText2ImageDataset:
|
||||
|
||||
def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
logger.info("Running validation... ")
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
@@ -279,7 +274,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
@@ -1177,11 +1172,6 @@ def main(args):
|
||||
).input_ids.to(accelerator.device)
|
||||
uncond_prompt_embeds = text_encoder(uncond_input_ids)[0]
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
# 16. Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1310,7 +1300,7 @@ def main(args):
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
@@ -1369,7 +1359,7 @@ def main(args):
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
with torch.no_grad():
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
|
||||
@@ -22,7 +22,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -147,12 +146,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
|
||||
@@ -24,7 +24,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
@@ -257,10 +256,6 @@ class SDXLText2ImageDataset:
|
||||
|
||||
def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
logger.info("Running validation... ")
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
@@ -296,7 +291,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step):
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
@@ -1358,12 +1353,7 @@ def main(args):
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
@@ -1426,12 +1416,7 @@ def main(args):
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
# Note that we do not use a separate target network for LCM-LoRA distillation.
|
||||
with torch.no_grad():
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda", enabled=True, dtype=weight_dtype):
|
||||
target_noise_pred = unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
|
||||
@@ -23,7 +23,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
@@ -253,12 +252,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda"):
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
@@ -1263,12 +1257,7 @@ def main(args):
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
@@ -1326,12 +1315,7 @@ def main(args):
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
with torch.no_grad():
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = target_unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
|
||||
@@ -24,7 +24,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
@@ -271,12 +270,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="targe
|
||||
|
||||
for _, prompt in enumerate(validation_prompts):
|
||||
images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda"):
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=4,
|
||||
@@ -1361,12 +1355,7 @@ def main(args):
|
||||
# estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
|
||||
# solver timestep.
|
||||
with torch.no_grad():
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda"):
|
||||
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
|
||||
cond_teacher_output = teacher_unet(
|
||||
noisy_model_input.to(weight_dtype),
|
||||
@@ -1428,12 +1417,7 @@ def main(args):
|
||||
|
||||
# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
|
||||
with torch.no_grad():
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type, dtype=weight_dtype)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda", dtype=weight_dtype):
|
||||
target_noise_pred = target_unet(
|
||||
x_prev.float(),
|
||||
timesteps,
|
||||
|
||||
@@ -752,10 +752,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import functools
|
||||
import gc
|
||||
import logging
|
||||
@@ -21,7 +22,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -125,10 +125,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
)
|
||||
|
||||
image_logs = []
|
||||
if is_final_validation or torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
|
||||
|
||||
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
|
||||
validation_image = Image.open(validation_image).convert("RGB")
|
||||
@@ -137,7 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step,
|
||||
images = []
|
||||
|
||||
for _ in range(args.num_validation_images):
|
||||
with autocast_ctx:
|
||||
with inference_ctx:
|
||||
image = pipeline(
|
||||
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
|
||||
).images[0]
|
||||
@@ -795,12 +792,6 @@ def main(args):
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
@@ -810,10 +801,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -676,10 +676,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -259,17 +259,13 @@ The authors found that by using DoRA, both the learning capacity and training st
|
||||
> This is also aligned with some of the quantitative analysis shown in the paper.
|
||||
|
||||
**Usage**
|
||||
1. To use DoRA you need to upgrade the installation of `peft`:
|
||||
1. To use DoRA you need to install `peft` from main:
|
||||
```bash
|
||||
pip install-U peft
|
||||
pip install git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
2. Enable DoRA training by adding this flag
|
||||
```bash
|
||||
--use_dora
|
||||
```
|
||||
**Inference**
|
||||
The inference is the same as if you train a regular LoRA 🤗
|
||||
|
||||
## Format compatibility
|
||||
|
||||
You can pass `--output_kohya_format` to additionally generate a state dictionary which should be compatible with other platforms and tools such as Automatic 1111, Comfy, Kohya, etc. The `output_dir` will contain a file named "pytorch_lora_weights_kohya.safetensors".
|
||||
The inference is the same as if you train a regular LoRA 🤗
|
||||
@@ -821,10 +821,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -749,10 +749,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import gc
|
||||
import itertools
|
||||
import json
|
||||
@@ -23,7 +24,6 @@ import os
|
||||
import random
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -41,7 +41,6 @@ from peft import LoraConfig, set_peft_model_state_dict
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
from safetensors.torch import load_file, save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
@@ -63,9 +62,7 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_all_state_dict_to_peft,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_kohya,
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
@@ -208,12 +205,11 @@ def log_validation(
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
inference_ctx = (
|
||||
contextlib.nullcontext() if "playground" in args.pretrained_model_name_or_path else torch.cuda.amp.autocast()
|
||||
)
|
||||
|
||||
with autocast_ctx:
|
||||
with inference_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
@@ -231,8 +227,7 @@ def log_validation(
|
||||
)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
@@ -401,11 +396,6 @@ def parse_args(input_args=None):
|
||||
default="lora-dreambooth-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_kohya_format",
|
||||
action="store_true",
|
||||
help="Flag to additionally generate final state dict in the Kohya format so that it becomes compatible with A111, Comfy, Kohya, etc.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
@@ -969,12 +959,6 @@ def main(args):
|
||||
if args.do_edm_style_training and args.snr_gamma is not None:
|
||||
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
@@ -987,10 +971,6 @@ def main(args):
|
||||
kwargs_handlers=[kwargs],
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
@@ -1021,8 +1001,7 @@ def main(args):
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
|
||||
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
if args.prior_generation_precision == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
elif args.prior_generation_precision == "fp16":
|
||||
@@ -1147,12 +1126,6 @@ def main(args):
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
@@ -1297,7 +1270,7 @@ def main(args):
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32 and torch.cuda.is_available():
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.scale_lr:
|
||||
@@ -1474,8 +1447,7 @@ def main(args):
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1918,11 +1890,6 @@ def main(args):
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
)
|
||||
if args.output_kohya_format:
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/pytorch_lora_weights_kohya.safetensors")
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
|
||||
@@ -21,7 +21,6 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -405,10 +404,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
@@ -948,12 +943,9 @@ def main():
|
||||
# run inference
|
||||
original_image = download_image(args.val_image_url)
|
||||
edited_images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast(
|
||||
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
|
||||
):
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
|
||||
@@ -20,7 +20,6 @@ import math
|
||||
import os
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -71,7 +70,14 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
|
||||
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
|
||||
|
||||
|
||||
def log_validation(pipeline, args, accelerator, generator, global_step, is_final_validation=False):
|
||||
def log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
generator,
|
||||
global_step,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
@@ -90,12 +96,7 @@ def log_validation(pipeline, args, accelerator, generator, global_step, is_final
|
||||
else Image.open(image_url_or_path).convert("RGB")
|
||||
)(args.val_image_url_or_path)
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
|
||||
edited_images = []
|
||||
# Run inference
|
||||
for val_img_idx in range(args.num_validation_images):
|
||||
@@ -496,13 +497,6 @@ def main():
|
||||
),
|
||||
)
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
@@ -511,10 +505,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
|
||||
@@ -458,10 +458,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -343,11 +343,6 @@ def main():
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -356,11 +356,6 @@ def main():
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -459,10 +459,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -916,10 +916,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -484,10 +484,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -526,10 +526,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
This project is an attempt to check if it's possible to apply to [ORPO](https://arxiv.org/abs/2403.07691) on a text-conditioned diffusion model to align it on preference data WITHOUT a reference model. The implementation is based on https://github.com/huggingface/trl/pull/1435/.
|
||||
|
||||
> [!WARNING]
|
||||
> We assume that MSE in the diffusion formulation approximates the log-probs as required by ORPO (hat-tip to [@kashif](https://github.com/kashif) for the idea). So, please consider this to be extremely experimental.
|
||||
|
||||
## Training
|
||||
|
||||
Here's training command you can use on a 40GB A100 to validate things on a [small preference
|
||||
dataset](https://hf.co/datasets/kashif/pickascore):
|
||||
|
||||
```bash
|
||||
accelerate launch train_diffusion_orpo_sdxl_lora.py \
|
||||
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--output_dir="diffusion-sdxl-orpo" \
|
||||
--mixed_precision="fp16" \
|
||||
--dataset_name=kashif/pickascore \
|
||||
--train_batch_size=8 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--learning_rate=1e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=2000 \
|
||||
--checkpointing_steps=500 \
|
||||
--run_validation --validation_steps=50 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
We also provide a simple script to scale up the training on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset:
|
||||
|
||||
```bash
|
||||
accelerate launch --multi_gpu train_diffusion_orpo_sdxl_lora_wds.py \
|
||||
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
|
||||
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
|
||||
--dataset_path="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -" \
|
||||
--output_dir="diffusion-sdxl-orpo-wds" \
|
||||
--mixed_precision="fp16" \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--rank=8 \
|
||||
--dataloader_num_workers=8 \
|
||||
--learning_rate=3e-5 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=50000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--run_validation --validation_steps=500 \
|
||||
--seed="0" \
|
||||
--report_to="wandb" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
We tested the above on a node of 8 H100s but it should also work on A100s. It requires the `webdataset` library for faster dataloading. Note that we kept the dataset shards on an S3 bucket but it should be also possible to have them stored locally.
|
||||
|
||||
You can use the code below to convert the original dataset into `webdataset` shards:
|
||||
|
||||
```python
|
||||
import os
|
||||
import io
|
||||
import ray
|
||||
import webdataset as wds
|
||||
from datasets import Dataset
|
||||
from PIL import Image
|
||||
|
||||
ray.init(num_cpus=8)
|
||||
|
||||
|
||||
def convert_to_image(im_bytes):
|
||||
return Image.open(io.BytesIO(im_bytes)).convert("RGB")
|
||||
|
||||
def main():
|
||||
dataset_path = "/pickapic_v2/data"
|
||||
wds_shards_path = "/pickapic_v2_webdataset"
|
||||
# get all .parquet files in the dataset path
|
||||
dataset_files = [
|
||||
os.path.join(dataset_path, f)
|
||||
for f in os.listdir(dataset_path)
|
||||
if f.endswith(".parquet")
|
||||
]
|
||||
|
||||
@ray.remote
|
||||
def create_shard(path):
|
||||
# get basename of the file
|
||||
basename = os.path.basename(path)
|
||||
# get the shard number data-00123-of-01034.parquet -> 00123
|
||||
shard_num = basename.split("-")[1]
|
||||
dataset = Dataset.from_parquet(path)
|
||||
# create a webdataset shard
|
||||
shard = wds.TarWriter(os.path.join(wds_shards_path, f"{shard_num}.tar"))
|
||||
|
||||
for i, example in enumerate(dataset):
|
||||
wds_example = {
|
||||
"__key__": str(i),
|
||||
"original_prompt.txt": example["caption"],
|
||||
"jpg_0.jpg": convert_to_image(example["jpg_0"]),
|
||||
"jpg_1.jpg": convert_to_image(example["jpg_1"]),
|
||||
"label_0.txt": str(example["label_0"]),
|
||||
"label_1.txt": str(example["label_1"])
|
||||
}
|
||||
shard.write(wds_example)
|
||||
shard.close()
|
||||
|
||||
futures = [create_shard.remote(path) for path in dataset_files]
|
||||
ray.get(futures)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Refer to [sayakpaul/diffusion-sdxl-orpo](https://huggingface.co/sayakpaul/diffusion-sdxl-orpo) for an experimental checkpoint.
|
||||
@@ -1,7 +0,0 @@
|
||||
datasets
|
||||
accelerate
|
||||
transformers
|
||||
torchvision
|
||||
wandb
|
||||
peft
|
||||
webdataset
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,6 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -411,10 +410,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
if args.report_to == "wandb":
|
||||
@@ -972,12 +967,9 @@ def main():
|
||||
# run inference
|
||||
original_image = download_image(args.val_image_url)
|
||||
edited_images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast(
|
||||
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
|
||||
):
|
||||
for _ in range(args.num_validation_images):
|
||||
edited_images.append(
|
||||
pipeline(
|
||||
|
||||
@@ -378,10 +378,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
@@ -411,11 +411,6 @@ def main():
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -698,10 +698,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -566,10 +566,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -439,10 +439,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -581,10 +581,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
-4
@@ -295,10 +295,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.logger == "tensorboard":
|
||||
if not is_tensorboard_available():
|
||||
raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
|
||||
|
||||
@@ -1005,7 +1005,7 @@ class PromptDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
`._callback_tensor_inputs` attribute of your pipeine class.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
@@ -799,10 +799,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -20,7 +20,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -165,12 +164,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight
|
||||
|
||||
images = []
|
||||
for i in range(len(args.validation_prompts)):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
|
||||
|
||||
images.append(image)
|
||||
@@ -529,10 +523,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -21,7 +21,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -409,11 +408,6 @@ def main():
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
@@ -884,12 +878,7 @@ def main():
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.cuda.amp.autocast():
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
@@ -959,12 +948,7 @@ def main():
|
||||
if args.seed is not None:
|
||||
generator = generator.manual_seed(args.seed)
|
||||
images = []
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.cuda.amp.autocast():
|
||||
for _ in range(args.num_validation_images):
|
||||
images.append(
|
||||
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
|
||||
|
||||
@@ -21,7 +21,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -502,12 +501,6 @@ def main(args):
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
accelerator = Accelerator(
|
||||
@@ -1205,12 +1198,8 @@ def main(args):
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.cuda.amp.autocast():
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -23,7 +23,6 @@ import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
@@ -591,12 +590,6 @@ def main(args):
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
@@ -604,10 +597,6 @@ def main(args):
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
@@ -991,11 +980,6 @@ def main(args):
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
return model
|
||||
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
@@ -1229,7 +1213,7 @@ def main(args):
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.cuda.amp.autocast():
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
@@ -1252,10 +1236,6 @@ def main(args):
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if args.use_ema:
|
||||
# Switch back to the original UNet parameters.
|
||||
ema_unet.restore(unet.parameters())
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
@@ -1288,8 +1268,7 @@ def main(args):
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.cuda.amp.autocast():
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
|
||||
@@ -20,7 +20,6 @@ import os
|
||||
import random
|
||||
import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -144,12 +143,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
|
||||
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
@@ -606,10 +600,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -605,10 +605,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
|
||||
@@ -460,10 +460,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -458,10 +458,6 @@ def main():
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
|
||||
@@ -23,14 +23,13 @@ To create the package for PyPI.
|
||||
If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make
|
||||
for the post-release and run `make fix-copies` on the main branch as well.
|
||||
|
||||
2. Unpin specific versions from setup.py that use a git install.
|
||||
2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.
|
||||
|
||||
3. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
|
||||
3. Unpin specific versions from setup.py that use a git install.
|
||||
|
||||
4. Checkout the release branch (v<RELEASE>-release, for example v4.19-release), and commit these changes with the
|
||||
message: "Release: <RELEASE>" and push.
|
||||
|
||||
4. Manually trigger the "Nightly and release tests on main/release branch" workflow from the release branch. Wait for
|
||||
the tests to complete. We can safely ignore the known test failures.
|
||||
|
||||
5. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs).
|
||||
|
||||
6. Add a tag in git to mark the release: "git tag v<RELEASE> -m 'Adds tag v<RELEASE> for PyPI'"
|
||||
@@ -73,11 +72,7 @@ To create the package for PyPI.
|
||||
9. Upload the final version to the actual PyPI:
|
||||
twine upload dist/* -r pypi
|
||||
|
||||
10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory. You can use the following
|
||||
Space to fetch all the commits applicable for the release: https://huggingface.co/spaces/lysandre/github-release. Repo should
|
||||
be `huggingface/diffusers`. `tag` should be the previous release tag (v0.26.1, for example), and `branch` should be
|
||||
the latest release branch (v0.27.0-release, for example). It denotes all commits that have happened on branch
|
||||
v0.27.0-release after the tag v0.26.1 was created.
|
||||
10. Prepare the release notes and publish them on GitHub once everything is looking hunky-dory.
|
||||
|
||||
11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release,
|
||||
you need to go back to main before executing this.
|
||||
@@ -86,8 +81,9 @@ To create the package for PyPI.
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from distutils.core import Command
|
||||
|
||||
from setuptools import Command, find_packages, setup
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
# IMPORTANT:
|
||||
@@ -134,7 +130,6 @@ _deps = [
|
||||
"torchvision",
|
||||
"transformers>=4.25.1",
|
||||
"urllib3<=2.0.0",
|
||||
"black",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
@@ -168,7 +163,7 @@ def deps_list(*pkgs):
|
||||
|
||||
class DepsTableUpdateCommand(Command):
|
||||
"""
|
||||
A custom command that updates the dependency table.
|
||||
A custom distutils command that updates the dependency table.
|
||||
usage: python setup.py deps_table_update
|
||||
"""
|
||||
|
||||
|
||||
@@ -42,5 +42,4 @@ deps = {
|
||||
"torchvision": "torchvision",
|
||||
"transformers": "transformers>=4.25.1",
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
"black": "black",
|
||||
}
|
||||
|
||||
@@ -173,9 +173,8 @@ class VaeImageProcessor(ConfigMixin):
|
||||
@staticmethod
|
||||
def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
|
||||
"""
|
||||
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect
|
||||
ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
|
||||
processing are 512x512, the region will be expanded to 128x128.
|
||||
Finds a rectangular region that contains all masked ares in an image, and expands region to match the aspect ratio of the original image;
|
||||
for example, if user drew mask in a 128x32 region, and the dimensions for processing are 512x512, the region will be expanded to 128x128.
|
||||
|
||||
Args:
|
||||
mask_image (PIL.Image.Image): Mask image.
|
||||
@@ -184,8 +183,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
pad (int, optional): Padding to be added to the crop region. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and
|
||||
matches the original aspect ratio.
|
||||
tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked ares in an image and matches the original aspect ratio.
|
||||
"""
|
||||
|
||||
mask_image = mask_image.convert("L")
|
||||
@@ -267,8 +265,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
||||
the image within the dimensions, filling empty with data from image.
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image.
|
||||
|
||||
Args:
|
||||
image: The image to resize.
|
||||
@@ -312,8 +309,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
height: int,
|
||||
) -> PIL.Image.Image:
|
||||
"""
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
||||
the image within the dimensions, cropping the excess.
|
||||
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess.
|
||||
|
||||
Args:
|
||||
image: The image to resize.
|
||||
@@ -350,12 +346,12 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The width to resize to.
|
||||
resize_mode (`str`, *optional*, defaults to `default`):
|
||||
The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
|
||||
within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
|
||||
will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
|
||||
then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
|
||||
the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
||||
the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
||||
supported for PIL image input.
|
||||
within the specified width and height, and it may not maintaining the original aspect ratio.
|
||||
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
||||
within the dimensions, filling empty with data from image.
|
||||
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
||||
within the dimensions, cropping the excess.
|
||||
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
|
||||
|
||||
Returns:
|
||||
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
||||
@@ -460,21 +456,19 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
Args:
|
||||
image (`pipeline_image_input`):
|
||||
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
|
||||
supported formats.
|
||||
The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of supported formats.
|
||||
height (`int`, *optional*, defaults to `None`):
|
||||
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
|
||||
height.
|
||||
The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default height.
|
||||
width (`int`, *optional*`, defaults to `None`):
|
||||
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
||||
The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
|
||||
resize_mode (`str`, *optional*, defaults to `default`):
|
||||
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
|
||||
the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
|
||||
resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
|
||||
center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
|
||||
image to fit within the specified width and height, maintaining the aspect ratio, and then center the
|
||||
image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
|
||||
supported for PIL image input.
|
||||
The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit
|
||||
within the specified width and height, and it may not maintaining the original aspect ratio.
|
||||
If `fill`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
||||
within the dimensions, filling empty with data from image.
|
||||
If `crop`, will resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image
|
||||
within the dimensions, cropping the excess.
|
||||
Note that resize_mode `fill` and `crop` are only supported for PIL image input.
|
||||
crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
|
||||
The crop coordinates for each image in the batch. If `None`, will not crop the image.
|
||||
"""
|
||||
@@ -936,8 +930,8 @@ class IPAdapterMaskProcessor(VaeImageProcessor):
|
||||
@staticmethod
|
||||
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
|
||||
"""
|
||||
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
|
||||
aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
|
||||
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention.
|
||||
If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
|
||||
|
||||
Args:
|
||||
mask (`torch.FloatTensor`):
|
||||
|
||||
@@ -67,18 +67,17 @@ class IPAdapterMixin:
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
If a list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`weight_name`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside `subfolder`,
|
||||
you only need to pass the name of the folder that contains image encoder weights, e.g. `image_encoder_folder="image_encoder"`.
|
||||
If the image encoder is located in a folder other than `subfolder`, you should pass the path to the folder that contains image encoder weights,
|
||||
for example, `image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
+11
-103
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -37,7 +36,6 @@ from ..utils import (
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
@@ -115,7 +113,7 @@ class LoraLoaderMixin:
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
@@ -453,15 +451,6 @@ class LoraLoaderMixin:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
@@ -583,15 +572,6 @@ class LoraLoaderMixin:
|
||||
}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict, is_unet=False)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
@@ -674,13 +654,6 @@ class LoraLoaderMixin:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
@@ -986,7 +959,7 @@ class LoraLoaderMixin:
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||
text_encoder_weights: List[float] = None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
@@ -1004,20 +977,15 @@ class LoraLoaderMixin:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
# Expand weights into a list, one entry per adapter
|
||||
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
|
||||
if not isinstance(weights, list):
|
||||
weights = [weights] * len(adapter_names)
|
||||
if weights is None:
|
||||
weights = [1.0] * len(adapter_names)
|
||||
elif isinstance(weights, float):
|
||||
weights = [weights]
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
|
||||
# Set None values to default of 1.0
|
||||
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
|
||||
weights = [w if w is not None else 1.0 for w in weights]
|
||||
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
@@ -1065,77 +1033,17 @@ class LoraLoaderMixin:
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
adapter_weights: Optional[List[float]] = None,
|
||||
):
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
|
||||
adapter_weights = copy.deepcopy(adapter_weights)
|
||||
|
||||
# Expand weights into a list, one entry per adapter
|
||||
if not isinstance(adapter_weights, list):
|
||||
adapter_weights = [adapter_weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(adapter_weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
||||
)
|
||||
|
||||
# Decompose weights into weights for unet, text_encoder and text_encoder_2
|
||||
unet_lora_weights, text_encoder_lora_weights, text_encoder_2_lora_weights = [], [], []
|
||||
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
all_adapters = {
|
||||
adapter for adapters in list_adapters.values() for adapter in adapters
|
||||
} # eg ["adapter1", "adapter2"]
|
||||
invert_list_adapters = {
|
||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||
for adapter in all_adapters
|
||||
} # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
|
||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||
if isinstance(weights, dict):
|
||||
unet_lora_weight = weights.pop("unet", None)
|
||||
text_encoder_lora_weight = weights.pop("text_encoder", None)
|
||||
text_encoder_2_lora_weight = weights.pop("text_encoder_2", None)
|
||||
|
||||
if len(weights) > 0:
|
||||
raise ValueError(
|
||||
f"Got invalid key '{weights.keys()}' in lora weight dict for adapter {adapter_name}."
|
||||
)
|
||||
|
||||
if text_encoder_2_lora_weight is not None and not hasattr(self, "text_encoder_2"):
|
||||
logger.warning(
|
||||
"Lora weight dict contains text_encoder_2 weights but will be ignored because pipeline does not have text_encoder_2."
|
||||
)
|
||||
|
||||
# warn if adapter doesn't have parts specified by adapter_weights
|
||||
for part_weight, part_name in zip(
|
||||
[unet_lora_weight, text_encoder_lora_weight, text_encoder_2_lora_weight],
|
||||
["unet", "text_encoder", "text_encoder_2"],
|
||||
):
|
||||
if part_weight is not None and part_name not in invert_list_adapters[adapter_name]:
|
||||
logger.warning(
|
||||
f"Lora weight dict for adapter '{adapter_name}' contains {part_name}, but this will be ignored because {adapter_name} does not contain weights for {part_name}. Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
|
||||
)
|
||||
|
||||
else:
|
||||
unet_lora_weight = weights
|
||||
text_encoder_lora_weight = weights
|
||||
text_encoder_2_lora_weight = weights
|
||||
|
||||
unet_lora_weights.append(unet_lora_weight)
|
||||
text_encoder_lora_weights.append(text_encoder_lora_weight)
|
||||
text_encoder_2_lora_weights.append(text_encoder_2_lora_weight)
|
||||
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
# Handle the UNET
|
||||
unet.set_adapters(adapter_names, unet_lora_weights)
|
||||
unet.set_adapters(adapter_names, adapter_weights)
|
||||
|
||||
# Handle the Text Encoder
|
||||
if hasattr(self, "text_encoder"):
|
||||
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, text_encoder_lora_weights)
|
||||
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, text_encoder_2_lora_weights)
|
||||
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
|
||||
|
||||
def disable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
@@ -1335,7 +1243,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
unet_config=self.unet.config,
|
||||
**kwargs,
|
||||
)
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import re
|
||||
|
||||
from ..utils import is_peft_version, logging
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -128,15 +128,6 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
||||
te_state_dict = {}
|
||||
te2_state_dict = {}
|
||||
network_alphas = {}
|
||||
is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
|
||||
is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
|
||||
is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
|
||||
|
||||
if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
|
||||
# every down weight has a corresponding up weight and potentially an alpha weight
|
||||
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
|
||||
@@ -207,12 +198,6 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
||||
unet_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
if is_unet_dora_lora:
|
||||
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
||||
unet_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
|
||||
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
|
||||
@@ -244,19 +229,6 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
|
||||
te2_state_dict[diffusers_name] = state_dict.pop(key)
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
dora_scale_key_to_replace_te = (
|
||||
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
||||
)
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
elif lora_name.startswith("lora_te2_"):
|
||||
te2_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
|
||||
# Rename the alphas so that they can be mapped appropriately.
|
||||
if lora_name_alpha in state_dict:
|
||||
alpha = state_dict.pop(lora_name_alpha).item()
|
||||
|
||||
@@ -20,8 +20,7 @@ from ..utils import MIN_PEFT_VERSION, check_peft_version, is_peft_available
|
||||
class PeftAdapterMixin:
|
||||
"""
|
||||
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
|
||||
more details about adapters and injecting them in a transformer-based model, check out the PEFT
|
||||
[documentation](https://huggingface.co/docs/peft/index).
|
||||
more details about adapters and injecting them in a transformer-based model, check out the PEFT [documentation](https://huggingface.co/docs/peft/index).
|
||||
|
||||
Install the latest version of PEFT, and use this mixin to:
|
||||
|
||||
@@ -144,8 +143,8 @@ class PeftAdapterMixin:
|
||||
|
||||
def enable_adapters(self) -> None:
|
||||
"""
|
||||
Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the list of
|
||||
adapters to enable.
|
||||
Enable adapters that are attached to the model. The model uses `self.active_adapters()` to retrieve the
|
||||
list of adapters to enable.
|
||||
|
||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
||||
[documentation](https://huggingface.co/docs/peft).
|
||||
|
||||
@@ -198,24 +198,19 @@ class FromSingleFileMixin:
|
||||
model_type (`str`, *optional*):
|
||||
The type of model to load. If not provided, the model type will be inferred from the checkpoint file.
|
||||
image_size (`int`, *optional*):
|
||||
The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE
|
||||
model.
|
||||
The size of the image output. It's used to configure the `sample_size` parameter of the UNet and VAE model.
|
||||
load_safety_checker (`bool`, *optional*, defaults to `False`):
|
||||
Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a
|
||||
`safety_checker` component is passed to the `kwargs`.
|
||||
Whether to load the safety checker model or not. By default, the safety checker is not loaded unless a `safety_checker` component is passed to the `kwargs`.
|
||||
num_in_channels (`int`, *optional*):
|
||||
Specify the number of input channels for the UNet model. Read more about how to configure UNet model
|
||||
with this parameter
|
||||
Specify the number of input channels for the UNet model. Read more about how to configure UNet model with this parameter
|
||||
[here](https://huggingface.co/docs/diffusers/training/adapt_a_model#configure-unet2dconditionmodel-parameters).
|
||||
scaling_factor (`float`, *optional*):
|
||||
The scaling factor to use for the VAE model. If not provided, it is inferred from the config file
|
||||
first. If the scaling factor is not found in the config file, the default value 0.18215 is used.
|
||||
The scaling factor to use for the VAE model. If not provided, it is inferred from the config file first.
|
||||
If the scaling factor is not found in the config file, the default value 0.18215 is used.
|
||||
scheduler_type (`str`, *optional*):
|
||||
The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint
|
||||
file.
|
||||
The type of scheduler to load. If not provided, the scheduler type will be inferred from the checkpoint file.
|
||||
prediction_type (`str`, *optional*):
|
||||
The type of prediction to load. If not provided, the prediction type will be inferred from the
|
||||
checkpoint file.
|
||||
The type of prediction to load. If not provided, the prediction type will be inferred from the checkpoint file.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
|
||||
@@ -50,8 +50,6 @@ if is_transformers_available():
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
CONFIG_URLS = {
|
||||
@@ -979,6 +977,8 @@ def create_diffusers_controlnet_model_from_ldm(
|
||||
controlnet = ControlNetModel(**diffusers_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
controlnet, diffusers_format_controlnet_checkpoint, dtype=torch_dtype
|
||||
)
|
||||
@@ -1155,6 +1155,8 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
|
||||
text_model_dict[diffusers_key] = checkpoint[key]
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
|
||||
if text_model._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in text_model._keys_to_ignore_on_load_unexpected:
|
||||
@@ -1248,6 +1250,8 @@ def create_text_encoder_from_open_clip_checkpoint(
|
||||
text_model_dict[diffusers_key] = checkpoint[key]
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict, dtype=torch_dtype)
|
||||
if text_model._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in text_model._keys_to_ignore_on_load_unexpected:
|
||||
@@ -1313,6 +1317,8 @@ def create_diffusers_unet_model_from_ldm(
|
||||
unet = UNet2DConditionModel(**unet_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint, dtype=torch_dtype)
|
||||
if unet._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in unet._keys_to_ignore_on_load_unexpected:
|
||||
@@ -1373,6 +1379,8 @@ def create_diffusers_vae_model_from_ldm(
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
|
||||
if is_accelerate_available():
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint, dtype=torch_dtype)
|
||||
if vae._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in vae._keys_to_ignore_on_load_unexpected:
|
||||
|
||||
@@ -487,35 +487,20 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
# Example 3: unload from SDXL
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
embedding_path = hf_hub_download(
|
||||
repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model"
|
||||
)
|
||||
embedding_path = hf_hub_download(repo_id="linoyts/web_y2k", filename="web_y2k_emb.safetensors", repo_type="model")
|
||||
|
||||
# load embeddings to the text encoders
|
||||
state_dict = load_file(embedding_path)
|
||||
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=["<s0>", "<s1>"],
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_g"],
|
||||
token=["<s0>", "<s1>"],
|
||||
text_encoder=pipeline.text_encoder_2,
|
||||
tokenizer=pipeline.tokenizer_2,
|
||||
)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
|
||||
# Unload explicitly from both text encoders abd tokenizers
|
||||
pipeline.unload_textual_inversion(
|
||||
tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
|
||||
)
|
||||
pipeline.unload_textual_inversion(
|
||||
tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
|
||||
)
|
||||
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.unload_textual_inversion(tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@@ -47,7 +47,6 @@ from .single_file_utils import (
|
||||
infer_stable_cascade_single_file_config,
|
||||
load_single_file_model_checkpoint,
|
||||
)
|
||||
from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
|
||||
@@ -565,7 +564,7 @@ class UNet2DConditionLoadersMixin:
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
|
||||
weights: Optional[Union[List[float], float]] = None,
|
||||
):
|
||||
"""
|
||||
Set the currently active adapters for use in the UNet.
|
||||
@@ -598,9 +597,9 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
|
||||
# Expand weights into a list, one entry per adapter
|
||||
# examples for e.g. 2 adapters: [{...}, 7] -> [7,7] ; None -> [None, None]
|
||||
if not isinstance(weights, list):
|
||||
if weights is None:
|
||||
weights = [1.0] * len(adapter_names)
|
||||
elif isinstance(weights, float):
|
||||
weights = [weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
@@ -608,13 +607,6 @@ class UNet2DConditionLoadersMixin:
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
|
||||
)
|
||||
|
||||
# Set None values to default of 1.0
|
||||
# e.g. [{...}, 7] -> [{...}, 7] ; [None, None] -> [1.0, 1.0]
|
||||
weights = [w if w is not None else 1.0 for w in weights]
|
||||
|
||||
# e.g. [{...}, 7] -> [{expanded dict...}, 7]
|
||||
weights = _maybe_expand_lora_scales(self, weights)
|
||||
|
||||
set_weights_and_activate_adapters(self, adapter_names, weights)
|
||||
|
||||
def disable_lora(self):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user