Compare commits
101 Commits
debug
...
fix/lora-dtype
| Author | SHA1 | Date | |
|---|---|---|---|
| 62ece6ab5a | |||
| 7722f5b67c | |||
| 0281e85827 | |||
| 9106e66382 | |||
| 2fb3d141c2 | |||
| 6e1b06c01c | |||
| 50f2544697 | |||
| fc5fc8c8d2 | |||
| 030bb528ba | |||
| f803d3d1f5 | |||
| 07b297e7de | |||
| 2bfa55f4ed | |||
| 9bc55e8b7f | |||
| 3758d7a8b0 | |||
| 3a794b54c9 | |||
| fe623f3bea | |||
| bc65f829b7 | |||
| c22be1a557 | |||
| 05f716d4ac | |||
| 25b0d5b8c4 | |||
| 4d2c981d55 | |||
| cf03f5b718 | |||
| 5313aa6231 | |||
| ea8364e581 | |||
| e00df25aee | |||
| 91fd181245 | |||
| 0fa32bd674 | |||
| aea73834f6 | |||
| a139213578 | |||
| 9c82b68f07 | |||
| d3e0750d5d | |||
| 4ac205e32f | |||
| ed2f956072 | |||
| a844065384 | |||
| 35952e61c1 | |||
| d199bc62ec | |||
| 8d314c96ee | |||
| e2c0208c86 | |||
| bd72927c63 | |||
| c4d66200b7 | |||
| 2ed7e05fc2 | |||
| cc2c4ae759 | |||
| 6bd55b54bc | |||
| 0513a8cfd8 | |||
| 306dc6e751 | |||
| dd25ef5679 | |||
| 016866792d | |||
| f0a2c63753 | |||
| 7eaae83f16 | |||
| 872ae1dd12 | |||
| 6ce01bd647 | |||
| 0922210c5c | |||
| 02a8d664a2 | |||
| e6faf607f7 | |||
| d8d8b2ae77 | |||
| 84b82a6cb7 | |||
| e46ec5f88f | |||
| 25c177aace | |||
| c7e08958b8 | |||
| dd5a36291f | |||
| 7271f8b717 | |||
| dfcce3ca6e | |||
| 2457599114 | |||
| bdd16116f3 | |||
| c8b0f0eb21 | |||
| 7a4324cce3 | |||
| 37a787a106 | |||
| d56825e4b4 | |||
| cd1b8d7ca8 | |||
| db91e710da | |||
| 2a62aadcff | |||
| 4f74a5e1f7 | |||
| bbe8d3ae13 | |||
| 907fd91ce9 | |||
| 0c7cb9a613 | |||
| 84e5cc596c | |||
| cc92332096 | |||
| 9cfd4ef074 | |||
| 78a78515d6 | |||
| 9c03a7da43 | |||
| 1d3120fbaa | |||
| c78ee143e9 | |||
| 622f35b1d0 | |||
| 39baf0b41b | |||
| 1c4c4c48d9 | |||
| d840253f6a | |||
| 536c297a14 | |||
| 693a0d08e4 | |||
| cac7adab11 | |||
| a584d42ce5 | |||
| cdcc01be0e | |||
| ba59e92fb0 | |||
| 02247d9ce1 | |||
| 940f9410cb | |||
| ad06e5106e | |||
| ae2fc01a91 | |||
| 16d56c4b4f | |||
| c82f7bafba | |||
| d9e7857af3 | |||
| fd1c54abf2 | |||
| 9946dcf8db |
@@ -13,7 +13,7 @@ body:
|
||||
*Give your issue a fitting title. Assume that someone which very limited knowledge of diffusers can understand your issue. Add links to the source code, documentation other issues, pull requests etc...*
|
||||
- 2. If your issue is about something not working, **always** provide a reproducible code snippet. The reader should be able to reproduce your issue by **only copy-pasting your code snippet into a Python shell**.
|
||||
*The community cannot solve your issue if it cannot reproduce it. If your bug is related to training, add your training script and make everything needed to train public. Otherwise, just add a simple Python code snippet.*
|
||||
- 3. Add the **minimum amount of code / context that is needed to understand, reproduce your issue**.
|
||||
- 3. Add the **minimum** amount of code / context that is needed to understand, reproduce your issue.
|
||||
*Make the life of maintainers easy. `diffusers` is getting many issues every day. Make sure your issue is about one bug and one bug only. Make sure you add only the context, code needed to understand your issues - nothing more. Generally, every issue is a way of documenting this library, try to make it a good documentation entry.*
|
||||
- 4. For issues related to community pipelines (i.e., the pipelines located in the `examples/community` folder), please tag the author of the pipeline in your issue thread as those pipelines are not maintained.
|
||||
- type: markdown
|
||||
@@ -61,21 +61,46 @@ body:
|
||||
All issues are read by one of the core maintainers, so if you don't know who to tag, just leave this blank and
|
||||
a core maintainer will ping the right person.
|
||||
|
||||
Please tag fewer than 3 people.
|
||||
|
||||
General library related questions: @patrickvonplaten and @sayakpaul
|
||||
Please tag a maximum of 2 people.
|
||||
|
||||
Questions on the training examples: @williamberman, @sayakpaul, @yiyixuxu
|
||||
Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...):
|
||||
|
||||
Questions on memory optimizations, LoRA, float16, etc.: @williamberman, @patrickvonplaten, and @sayakpaul
|
||||
Questions on pipelines:
|
||||
- Stable Diffusion @yiyixuxu @DN6 @patrickvonplaten @sayakpaul @patrickvonplaten
|
||||
- Stable Diffusion XL @yiyixuxu @sayakpaul @DN6 @patrickvonplaten
|
||||
- Kandinsky @yiyixuxu @patrickvonplaten
|
||||
- ControlNet @sayakpaul @yiyixuxu @DN6 @patrickvonplaten
|
||||
- T2I Adapter @sayakpaul @yiyixuxu @DN6 @patrickvonplaten
|
||||
- IF @DN6 @patrickvonplaten
|
||||
- Text-to-Video / Video-to-Video @DN6 @sayakpaul @patrickvonplaten
|
||||
- Wuerstchen @DN6 @patrickvonplaten
|
||||
- Other: @yiyixuxu @DN6
|
||||
|
||||
Questions on schedulers: @patrickvonplaten and @williamberman
|
||||
Questions on models:
|
||||
- UNet @DN6 @yiyixuxu @sayakpaul @patrickvonplaten
|
||||
- VAE @sayakpaul @DN6 @yiyixuxu @patrickvonplaten
|
||||
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6 @patrickvonplaten
|
||||
|
||||
Questions on models and pipelines: @patrickvonplaten, @sayakpaul, and @williamberman (for community pipelines, please tag the original author of the pipeline)
|
||||
Questions on Schedulers: @yiyixuxu @patrickvonplaten
|
||||
|
||||
Questions on LoRA: @sayakpaul @patrickvonplaten
|
||||
|
||||
Questions on Textual Inversion: @sayakpaul @patrickvonplaten
|
||||
|
||||
Questions on Training:
|
||||
- DreamBooth @sayakpaul @patrickvonplaten
|
||||
- Text-to-Image Fine-tuning @sayakpaul @patrickvonplaten
|
||||
- Textual Inversion @sayakpaul @patrickvonplaten
|
||||
- ControlNet @sayakpaul @patrickvonplaten
|
||||
|
||||
Questions on Tests: @DN6 @sayakpaul @yiyixuxu
|
||||
|
||||
Questions on Documentation: @stevhliu
|
||||
|
||||
Questions on JAX- and MPS-related things: @pcuenca
|
||||
|
||||
Questions on audio pipelines: @patrickvonplaten, @kashif, and @sanchit-gandhi
|
||||
Questions on audio pipelines: @DN6 @patrickvonplaten
|
||||
|
||||
|
||||
|
||||
Documentation: @stevhliu and @yiyixuxu
|
||||
placeholder: "@Username ..."
|
||||
|
||||
@@ -27,6 +27,7 @@ jobs:
|
||||
- diffusers-pytorch-cpu
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-pytorch-compile-cuda
|
||||
- diffusers-pytorch-xformers-cuda
|
||||
- diffusers-flax-cpu
|
||||
- diffusers-flax-tpu
|
||||
- diffusers-onnxruntime-cpu
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
name: Slow tests on main
|
||||
name: Slow Tests on main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
HF_HOME: /mnt/cache
|
||||
@@ -12,53 +13,115 @@ env:
|
||||
MKL_NUM_THREADS: 8
|
||||
PYTEST_TIMEOUT: 600
|
||||
RUN_SLOW: yes
|
||||
PIPELINE_USAGE_CUTOFF: 50000
|
||||
|
||||
jobs:
|
||||
run_slow_tests:
|
||||
setup_torch_cuda_pipeline_matrix:
|
||||
name: Setup Torch Pipelines CUDA Slow Tests Matrix
|
||||
runs-on: docker-gpu
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu # this is a CPU image, but we need it to fetch the matrix
|
||||
options: --shm-size "16gb" --ipc host
|
||||
outputs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Fetch Pipeline Matrix
|
||||
id: fetch_pipeline_matrix
|
||||
run: |
|
||||
matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
|
||||
echo $matrix
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
|
||||
torch_pipelines_cuda_tests:
|
||||
name: Torch Pipelines CUDA Slow Tests
|
||||
needs: setup_torch_cuda_pipeline_matrix
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1
|
||||
matrix:
|
||||
config:
|
||||
- name: Slow PyTorch CUDA tests on Ubuntu
|
||||
framework: pytorch
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
report: torch_cuda
|
||||
- name: Slow Flax TPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: docker-tpu
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
report: flax_tpu
|
||||
- name: Slow ONNXRuntime CUDA tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: docker-gpu
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
report: onnx_cuda
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on: ${{ matrix.config.runner }}
|
||||
|
||||
module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
|
||||
runs-on: docker-gpu
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ ${{ matrix.config.runner == 'docker-tpu' && '--privileged' || '--gpus 0'}}
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
|
||||
tests/pipelines/${{ matrix.module }}
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
|
||||
torch_cuda_tests:
|
||||
name: Torch CUDA Tests
|
||||
runs-on: docker-gpu
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
module: [models, schedulers, lora, others]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
if : ${{ matrix.config.runner == 'docker-gpu' }}
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
@@ -70,47 +133,121 @@ jobs:
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow PyTorch CUDA tests
|
||||
if: ${{ matrix.config.framework == 'pytorch' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "not Flax and not Onnx and not compile" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run slow Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run slow ONNXRuntime CUDA tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
-s -v -k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_cuda \
|
||||
tests/${{ matrix.module }}
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
run: |
|
||||
cat reports/tests_torch_cuda_stats.txt
|
||||
cat reports/tests_torch_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: ${{ matrix.config.report }}_test_reports
|
||||
name: torch_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
flax_tpu_tests:
|
||||
name: Flax TPU Tests
|
||||
runs-on: docker-tpu
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow Flax TPU tests
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
onnx_cuda_tests:
|
||||
name: ONNX CUDA Tests
|
||||
runs-on: docker-gpu
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt-get update && apt-get install libsndfile1-dev libgl1 -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow ONNXRuntime CUDA tests
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: onnx_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
run_torch_compile_tests:
|
||||
@@ -131,21 +268,17 @@ jobs:
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -e .[quality,test,training]
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_compile_cuda_failures_short.txt
|
||||
@@ -157,6 +290,46 @@ jobs:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
|
||||
run_xformers_tests:
|
||||
name: PyTorch xformers CUDA tests
|
||||
|
||||
runs-on: docker-gpu
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-xformers-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: |
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -e .[quality,test,training]
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_xformers_test_reports
|
||||
path: reports
|
||||
|
||||
run_examples_tests:
|
||||
name: Examples PyTorch CUDA tests on Ubuntu
|
||||
|
||||
@@ -192,11 +365,13 @@ jobs:
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/examples_torch_cuda_failures_short.txt
|
||||
run: |
|
||||
cat reports/examples_torch_cuda_stats.txt
|
||||
cat reports/examples_torch_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
path: reports
|
||||
@@ -14,22 +14,23 @@ RUN apt update && \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.9 \
|
||||
python3.9-dev \
|
||||
python3-pip \
|
||||
python3.9-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
RUN python3.9 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
RUN python3.9 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
python3.9 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
@@ -40,8 +41,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf \
|
||||
pytorch-lightning \
|
||||
xformers
|
||||
omegaconf
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -40,8 +40,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf \
|
||||
pytorch-lightning \
|
||||
xformers
|
||||
omegaconf
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
FROM nvidia/cuda:11.7.1-cudnn8-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt update && \
|
||||
apt install -y bash \
|
||||
build-essential \
|
||||
git \
|
||||
git-lfs \
|
||||
curl \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.8 \
|
||||
python3-pip \
|
||||
python3.8-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
# make sure to use venv
|
||||
RUN python3 -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
torch==2.0.1 \
|
||||
torchvision==0.15.2 \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
python3 -m pip install --no-cache-dir \
|
||||
accelerate \
|
||||
datasets \
|
||||
hf-doc-builder \
|
||||
huggingface-hub \
|
||||
Jinja2 \
|
||||
librosa \
|
||||
numpy \
|
||||
scipy \
|
||||
tensorboard \
|
||||
transformers \
|
||||
omegaconf \
|
||||
xformers
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
+8
-8
@@ -128,7 +128,7 @@ When adding a new pipeline:
|
||||
- Possible an end-to-end example of how to use it
|
||||
- Add all the pipeline classes that should be linked in the diffusion model. These classes should be added using our Markdown syntax. By default as follows:
|
||||
|
||||
```
|
||||
```py
|
||||
## XXXPipeline
|
||||
|
||||
[[autodoc]] XXXPipeline
|
||||
@@ -138,7 +138,7 @@ When adding a new pipeline:
|
||||
|
||||
This will include every public method of the pipeline that is documented, as well as the `__call__` method that is not documented by default. If you just want to add additional methods that are not documented, you can put the list of all methods to add in a list that contains `all`.
|
||||
|
||||
```
|
||||
```py
|
||||
[[autodoc]] XXXPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -172,7 +172,7 @@ Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`)
|
||||
an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its
|
||||
description:
|
||||
|
||||
```
|
||||
```py
|
||||
Args:
|
||||
n_layers (`int`): The number of layers of the model.
|
||||
```
|
||||
@@ -182,7 +182,7 @@ after the argument.
|
||||
|
||||
Here's an example showcasing everything so far:
|
||||
|
||||
```
|
||||
```py
|
||||
Args:
|
||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
||||
Indices of input sequence tokens in the vocabulary.
|
||||
@@ -196,13 +196,13 @@ Here's an example showcasing everything so far:
|
||||
For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the
|
||||
following signature:
|
||||
|
||||
```
|
||||
```py
|
||||
def my_function(x: str = None, a: float = 1):
|
||||
```
|
||||
|
||||
then its documentation should look like this:
|
||||
|
||||
```
|
||||
```py
|
||||
Args:
|
||||
x (`str`, *optional*):
|
||||
This argument controls ...
|
||||
@@ -235,14 +235,14 @@ building the return.
|
||||
|
||||
Here's an example of a single value return:
|
||||
|
||||
```
|
||||
```py
|
||||
Returns:
|
||||
`List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.
|
||||
```
|
||||
|
||||
Here's an example of a tuple return, comprising several objects:
|
||||
|
||||
```
|
||||
```py
|
||||
Returns:
|
||||
`tuple(torch.FloatTensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:
|
||||
- ** loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.FloatTensor` of shape `(1,)` --
|
||||
|
||||
@@ -58,6 +58,8 @@
|
||||
title: Control image brightness
|
||||
- local: using-diffusers/weighted_prompts
|
||||
title: Prompt weighting
|
||||
- local: using-diffusers/freeu
|
||||
title: Improve generation quality with FreeU
|
||||
title: Techniques
|
||||
- sections:
|
||||
- local: using-diffusers/pipeline_overview
|
||||
@@ -104,6 +106,8 @@
|
||||
title: Custom Diffusion
|
||||
- local: training/t2i_adapters
|
||||
title: T2I-Adapters
|
||||
- local: training/ddpo
|
||||
title: Reinforcement learning training with DDPO
|
||||
title: Training
|
||||
- sections:
|
||||
- local: using-diffusers/other-modalities
|
||||
|
||||
@@ -28,8 +28,8 @@ This model was contributed by the community contributor [HimariO](https://github
|
||||
|
||||
| Pipeline | Tasks | Demo
|
||||
|---|---|:---:|
|
||||
| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
|
||||
| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | -
|
||||
| [StableDiffusionAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning* | -
|
||||
| [StableDiffusionXLAdapterPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py) | *Text-to-Image Generation with T2I-Adapter Conditioning on StableDiffusion-XL* | -
|
||||
|
||||
## Usage example with the base model of StableDiffusion-1.4/1.5
|
||||
|
||||
|
||||
@@ -321,21 +321,9 @@ with torch.inference_mode():
|
||||
|
||||
Recent work on optimizing bandwidth in the attention block has generated huge speed-ups and reductions in GPU memory usage. The most recent type of memory-efficient attention is [Flash Attention](https://arxiv.org/pdf/2205.14135.pdf) (you can check out the original code at [HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)).
|
||||
|
||||
The table below details the speed-ups from a few different Nvidia GPUs when running inference on image sizes of 512x512 and a batch size of 1 (one prompt):
|
||||
<Tip>
|
||||
|
||||
| GPU | base attention (fp16) | memory-efficient attention (fp16) |
|
||||
|------------------|-----------------------|-----------------------------------|
|
||||
| NVIDIA Tesla T4 | 3.5it/s | 5.5it/s |
|
||||
| NVIDIA 3060 RTX | 4.6it/s | 7.8it/s |
|
||||
| NVIDIA A10G | 8.88it/s | 15.6it/s |
|
||||
| NVIDIA RTX A6000 | 11.7it/s | 21.09it/s |
|
||||
| NVIDIA TITAN RTX | 12.51it/s | 18.22it/s |
|
||||
| A100-SXM4-40GB | 18.6it/s | 29.it/s |
|
||||
| A100-SXM-80GB | 18.7it/s | 29.5it/s |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
If you have PyTorch 2.0 installed, you shouldn't use xFormers!
|
||||
If you have PyTorch >= 2.0 installed, you should not expect a speed-up for inference when enabling `xformers`.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -365,3 +353,5 @@ with torch.inference_mode():
|
||||
# optional: You can disable it via
|
||||
# pipe.disable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
The iteration speed when using `xformers` should match the iteration speed of Torch 2.0 as described [here](torch2.0).
|
||||
|
||||
@@ -276,6 +276,7 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 22.24 | 23.23 | 43.76 | 49.25 |
|
||||
| SD - controlnet | 15.02 | 15.82 | 32.13 | 36.08 |
|
||||
| IF | 20.21 / <br>13.84 / <br>24.00 | 20.12 / <br>13.70 / <br>24.03 | ❌ | 97.34 / <br>27.23 / <br>111.66 |
|
||||
| SDXL - txt2img | 8.64 | 9.9 | - | - |
|
||||
|
||||
### A100 (batch size: 4)
|
||||
|
||||
@@ -286,6 +287,7 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 11.67 | 13.31 | 14.88 | 17.48 |
|
||||
| SD - controlnet | 8.28 | 9.38 | 10.51 | 12.41 |
|
||||
| IF | 25.02 | 18.04 | ❌ | 48.47 |
|
||||
| SDXL - txt2img | 2.44 | 2.74 | - | - |
|
||||
|
||||
### A100 (batch size: 16)
|
||||
|
||||
@@ -296,6 +298,7 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 3.04 | 3.66 | 3.9 | 4.76 |
|
||||
| SD - controlnet | 2.15 | 2.58 | 2.74 | 3.35 |
|
||||
| IF | 8.78 | 9.82 | ❌ | 16.77 |
|
||||
| SDXL - txt2img | 0.64 | 0.72 | - | - |
|
||||
|
||||
### V100 (batch size: 1)
|
||||
|
||||
@@ -336,6 +339,7 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 6.91 | 6.7 | 7.01 | 7.37 |
|
||||
| SD - controlnet | 4.89 | 4.86 | 5.35 | 5.48 |
|
||||
| IF | 17.42 / <br>2.47 / <br>18.52 | 16.96 / <br>2.45 / <br>18.69 | ❌ | 24.63 / <br>2.47 / <br>23.39 |
|
||||
| SDXL - txt2img | 1.15 | 1.16 | - | - |
|
||||
|
||||
### T4 (batch size: 4)
|
||||
|
||||
@@ -346,6 +350,7 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 1.81 | 1.82 | 2.09 | 2.09 |
|
||||
| SD - controlnet | 1.34 | 1.27 | 1.47 | 1.46 |
|
||||
| IF | 5.79 | 5.61 | ❌ | 7.39 |
|
||||
| SDXL - txt2img | 0.288 | 0.289 | - | - |
|
||||
|
||||
### T4 (batch size: 16)
|
||||
|
||||
@@ -356,6 +361,7 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 2.30s | 2.26s | OOM after 2nd iteration | 1.95s |
|
||||
| SD - controlnet | OOM after 2nd iteration | OOM after 2nd iteration | OOM after warmup | OOM after warmup |
|
||||
| IF * | 1.44 | 1.44 | ❌ | 1.94 |
|
||||
| SDXL - txt2img | OOM | OOM | - | - |
|
||||
|
||||
### RTX 3090 (batch size: 1)
|
||||
|
||||
@@ -396,6 +402,7 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 40.51 | 41.88 | 44.58 | 49.72 |
|
||||
| SD - controlnet | 29.27 | 30.29 | 32.26 | 36.03 |
|
||||
| IF | 69.71 / <br>18.78 / <br>85.49 | 69.13 / <br>18.80 / <br>85.56 | ❌ | 124.60 / <br>26.37 / <br>138.79 |
|
||||
| SDXL - txt2img | 6.8 | 8.18 | - | - |
|
||||
|
||||
### RTX 4090 (batch size: 4)
|
||||
|
||||
@@ -406,6 +413,7 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 12.65 | 12.81 | 15.3 | 15.58 |
|
||||
| SD - controlnet | 9.1 | 9.25 | 11.03 | 11.22 |
|
||||
| IF | 31.88 | 31.14 | ❌ | 43.92 |
|
||||
| SDXL - txt2img | 2.19 | 2.35 | - | - |
|
||||
|
||||
### RTX 4090 (batch size: 16)
|
||||
|
||||
@@ -416,10 +424,11 @@ In the following tables, we report our findings in terms of the *number of itera
|
||||
| SD - inpaint | 3.17 | 3.2 | 3.85 | 3.85 |
|
||||
| SD - controlnet | 2.23 | 2.3 | 2.7 | 2.75 |
|
||||
| IF | 9.26 | 9.2 | ❌ | 13.31 |
|
||||
| SDXL - txt2img | 0.52 | 0.53 | - | - |
|
||||
|
||||
## Notes
|
||||
|
||||
* Follow this [PR](https://github.com/huggingface/diffusers/pull/3313) for more details on the environment used for conducting the benchmarks.
|
||||
* For the DeepFloyd IF pipeline where batch sizes > 1, we only used a batch size of > 1 in the first IF pipeline for text-to-image generation and NOT for upscaling. That means the two upscaling pipelines received a batch size of 1.
|
||||
|
||||
*Thanks to [Horace He](https://github.com/Chillee) from the PyTorch team for their support in improving our support of `torch.compile()` in Diffusers.*
|
||||
*Thanks to [Horace He](https://github.com/Chillee) from the PyTorch team for their support in improving our support of `torch.compile()` in Diffusers.*
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Reinforcement learning training with DDPO
|
||||
|
||||
You can fine-tune Stable Diffusion on a reward function via reinforcement learning with the 🤗 TRL library and 🤗 Diffusers. This is done with the Denoising Diffusion Policy Optimization (DDPO) algorithm introduced by Black et al. in [Training Diffusion Models with Reinforcement Learning](https://arxiv.org/abs/2305.13301), which is implemented in 🤗 TRL with the [`~trl.DDPOTrainer`].
|
||||
|
||||
For more information, check out the [`~trl.DDPOTrainer`] API reference and the [Finetune Stable Diffusion Models with DDPO via TRL](https://huggingface.co/blog/trl-ddpo) blog post.
|
||||
@@ -10,51 +10,297 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Conditional image generation
|
||||
# Text-to-image
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
Conditional image generation allows you to generate images from a text prompt. The text is converted into embeddings which are used to condition the model to generate an image from noise.
|
||||
When you think of diffusion models, text-to-image is usually one of the first things that come to mind. Text-to-image generates an image from a text description (for example, "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k") which is also known as a *prompt*.
|
||||
|
||||
The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion system for inference.
|
||||
From a very high level, a diffusion model takes a prompt and some random initial noise, and iteratively removes the noise to construct an image. The *denoising* process is guided by the prompt, and once the denoising process ends after a predetermined number of time steps, the image representation is decoded into an image.
|
||||
|
||||
Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) you would like to download.
|
||||
<Tip>
|
||||
|
||||
In this guide, you'll use [`DiffusionPipeline`] for text-to-image generation with [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5):
|
||||
Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog post to learn more about how a latent diffusion model works.
|
||||
|
||||
```python
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
</Tip>
|
||||
|
||||
>>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True)
|
||||
You can generate images from a prompt in 🤗 Diffusers in two steps:
|
||||
|
||||
1. Load a checkpoint into the [`AutoPipelineForText2Image`] class, which automatically detects the appropriate pipeline class to use based on the checkpoint:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
|
||||
Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on a GPU.
|
||||
You can move the generator object to a GPU, just like you would in PyTorch:
|
||||
2. Pass a prompt to the pipeline to generate an image:
|
||||
|
||||
```python
|
||||
>>> generator.to("cuda")
|
||||
```py
|
||||
image = pipeline(
|
||||
"stained glass of darth vader, backlight, centered composition, masterpiece, photorealistic, 8k"
|
||||
).images[0]
|
||||
```
|
||||
|
||||
Now you can use the `generator` on your text prompt:
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-vader.png"/>
|
||||
</div>
|
||||
|
||||
```python
|
||||
>>> image = generator("An image of a squirrel in Picasso style").images[0]
|
||||
## Popular models
|
||||
|
||||
The most common text-to-image models are [Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5), [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder). There are also ControlNet models or adapters that can be used with text-to-image models for more direct control in generating images. The results from each model are slightly different because of their architecture and training process, but no matter which model you choose, their usage is more or less the same. Let's use the same prompt for each model and compare their results.
|
||||
|
||||
### Stable Diffusion v1.5
|
||||
|
||||
[Stable Diffusion v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) is a latent diffusion model initialized from [Stable Diffusion v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4), and finetuned for 595K steps on 512x512 images from the LAION-Aesthetics V2 dataset. You can use this model like:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
generator = torch.Generator("cuda").manual_seed(31)
|
||||
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
|
||||
```
|
||||
|
||||
The output is by default wrapped into a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) object.
|
||||
### Stable Diffusion XL
|
||||
|
||||
You can save the image by calling:
|
||||
SDXL is a much larger version of the previous Stable Diffusion models, and involves a two-stage model process that adds even more details to an image. It also includes some additional *micro-conditionings* to generate high-quality images centered subjects. Take a look at the more comprehensive [SDXL](sdxl) guide to learn more about how to use it. In general, you can use SDXL like:
|
||||
|
||||
```python
|
||||
>>> image.save("image_of_squirrel_painting.png")
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
generator = torch.Generator("cuda").manual_seed(31)
|
||||
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
|
||||
```
|
||||
|
||||
Try out the Spaces below, and feel free to play around with the guidance scale parameter to see how it affects the image quality!
|
||||
### Kandinsky 2.2
|
||||
|
||||
<iframe
|
||||
src="https://stabilityai-stable-diffusion.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="500"
|
||||
></iframe>
|
||||
The Kandinsky model is a bit different from the Stable Diffusion models because it also uses an image prior model to create embeddings that are used to better align text and images in the diffusion model.
|
||||
|
||||
The easiest way to use Kandinsky 2.2 is:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
generator = torch.Generator("cuda").manual_seed(31)
|
||||
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", generator=generator).images[0]
|
||||
```
|
||||
|
||||
### ControlNet
|
||||
|
||||
ControlNet are auxiliary models or adapters that are finetuned on top of text-to-image models, such as [Stable Diffusion V1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5). Using ControlNet models in combination with text-to-image models offers diverse options for more explicit control over how to generate an image. With ControlNet's, you add an additional conditioning input image to the model. For example, if you provide an image of a human pose (usually represented as multiple keypoints that are connected into a skeleton) as a conditioning input, the model generates an image that follows the pose of the image. Check out the more in-depth [ControlNet](controlnet) guide to learn more about other conditioning inputs and how to use them.
|
||||
|
||||
In this example, let's condition the ControlNet with a human pose estimation image. Load the ControlNet model pretrained on human pose estimations:
|
||||
|
||||
```py
|
||||
from diffusers import ControlNetModel, AutoPipelineForText2Image
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pose_image = load_image("https://huggingface.co/lllyasviel/control_v11p_sd15_openpose/resolve/main/images/control.png")
|
||||
```
|
||||
|
||||
Pass the `controlnet` to the [`AutoPipelineForText2Image`], and provide the prompt and pose estimation image:
|
||||
|
||||
```py
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
generator = torch.Generator("cuda").manual_seed(31)
|
||||
image = pipeline("Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", image=pose_image, generator=generator).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Stable Diffusion v1.5</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Stable Diffusion XL</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Kandinsky 2.2</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-3.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">ControlNet (pose conditioning)</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Configure pipeline parameters
|
||||
|
||||
There are a number of parameters that can be configured in the pipeline that affect how an image is generated. You can change the image's output size, specify a negative prompt to improve image quality, and more. This section dives deeper into how to use these parameters.
|
||||
|
||||
### Height and width
|
||||
|
||||
The `height` and `width` parameters control the height and width (in pixels) of the generated image. By default, the Stable Diffusion v1.5 model outputs 512x512 images, but you can change this to any size that is a multiple of 8. For example, to create a rectangular image:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
image = pipeline(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", height=768, width=512
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-hw.png"/>
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Other models may have different default image sizes depending on the image size's in the training dataset. For example, SDXL's default image size is 1024x1024 and using lower `height` and `width` values may result in lower quality images. Make sure you check the model's API reference first!
|
||||
|
||||
</Tip>
|
||||
|
||||
### Guidance scale
|
||||
|
||||
The `guidance_scale` parameter affects how much the prompt influences image generation. A lower value gives the model "creativity" to generate images that are more loosely related to the prompt. Higher `guidance_scale` values push the model to follow the prompt more closely, and if this value is too high, you may observe some artifacts in the generated image.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
image = pipeline(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", guidance_scale=3.5
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-guidance-scale-2.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 2.5</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-guidance-scale-7.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 7.5</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-guidance-scale-10.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 10.5</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Negative prompt
|
||||
|
||||
Just like how a prompt guides generation, a *negative prompt* steers the model away from things you don't want the model to generate. This is commonly used to improve overall image quality by removing poor or bad image features such as "low resolution" or "bad details". You can also use a negative prompt to remove or modify the content and style of an image.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
image = pipeline(
|
||||
prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
||||
negative_prompt="ugly, deformed, disfigured, poor details, bad anatomy",
|
||||
).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-neg-prompt-1.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">negative prompt = "ugly, deformed, disfigured, poor details, bad anatomy"</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/text2img-neg-prompt-2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">negative prompt = "astronaut"</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Generator
|
||||
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html#generator) object enables reproducibility in a pipeline by setting a manual seed. You can use a `Generator` to generate batches of images and iteratively improve on an image generated from a seed as detailed in the [Improve image quality with deterministic generation](reusing_seeds) guide.
|
||||
|
||||
You can set a seed and `Generator` as shown below. Creating an image with a `Generator` should return the same result each time instead of randomly generating a new image.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
generator = torch.Generator(device="cuda").manual_seed(30)
|
||||
image = pipeline(
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
||||
generator=generator,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Control image generation
|
||||
|
||||
There are several ways to exert more control over how an image is generated outside of configuring a pipeline's parameters, such as prompt weighting and ControlNet models.
|
||||
|
||||
### Prompt weighting
|
||||
|
||||
Prompt weighting is a technique for increasing or decreasing the importance of concepts in a prompt to emphasize or minimize certain features in an image. We recommend using the [Compel](https://github.com/damian0815/compel) library to help you generate the weighted prompt embeddings.
|
||||
|
||||
<Tip>
|
||||
|
||||
Learn how to create the prompt embeddings in the [Prompt weighting](weighted_prompts) guide. This example focuses on how to use the prompt embeddings in the pipeline.
|
||||
|
||||
</Tip>
|
||||
|
||||
Once you've created the embeddings, you can pass them to the `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter in the pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
image = pipeline(
|
||||
prompt_emebds=prompt_embeds, # generated from Compel
|
||||
negative_prompt_embeds=negative_prompt_embeds, # generated from Compel
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### ControlNet
|
||||
|
||||
As you saw in the [ControlNet](#controlnet) section, these models offer a more flexible and accurate way to generate images by incorporating an additional conditioning image input. Each ControlNet model is pretrained on a particular type of conditioning image to generate new images that resemble it. For example, if you take a ControlNet pretrained on depth maps, you can give the model a depth map as a conditioning input and it'll generate an image that preserves the spatial information in it. This is quicker and easier than specifying the depth information in a prompt. You can even combine multiple conditioning inputs with a [MultiControlNet](controlnet#multicontrolnet)!
|
||||
|
||||
There are many types of conditioning inputs you can use, and 🤗 Diffusers supports ControlNet for Stable Diffusion and SDXL models. Take a look at the more comprehensive [ControlNet](controlnet) guide to learn how you can use these models.
|
||||
|
||||
## Optimize
|
||||
|
||||
Diffusion models are large, and the iterative nature of denoising an image is computationally expensive and intensive. But this doesn't mean you need access to powerful - or even many - GPUs to use them. There are many optimization techniques for running diffusion models on consumer and free-tier resources. For example, you can load model weights in half-precision to save GPU memory and increase speed or offload the entire model to the GPU to save even more memory.
|
||||
|
||||
PyTorch 2.0 also supports a more memory-efficient attention mechanism called [*scaled dot product attention*](../optimization/torch2.0#scaled-dot-product-attention) that is automatically enabled if you're using PyTorch 2.0. You can combine this with [`torch.compile`](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) to speed your code up even more:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16").to("cuda")
|
||||
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overheard", fullgraph=True)
|
||||
```
|
||||
|
||||
For more tips on how to optimize your code to save memory and speed up inference, read the [Memory and speed](../optimization/fp16) and [Torch 2.0](../optimization/torch2.0) guides.
|
||||
@@ -0,0 +1,123 @@
|
||||
# Improve generation quality with FreeU
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
The UNet is responsible for denoising during the reverse diffusion process, and there are two distinct features in its architecture:
|
||||
|
||||
1. Backbone features primarily contribute to the denoising process
|
||||
2. Skip features mainly introduce high-frequency features into the decoder module and can make the network overlook the semantics in the backbone features
|
||||
|
||||
However, the skip connection can sometimes introduce unnatural image details. [FreeU](https://hf.co/papers/2309.11497) is a technique for improving image quality by rebalancing the contributions from the UNet’s skip connections and backbone feature maps.
|
||||
|
||||
FreeU is applied during inference and it does not require any additional training. The technique works for different tasks such as text-to-image, image-to-image, and text-to-video.
|
||||
|
||||
In this guide, you will apply FreeU to the [`StableDiffusionPipeline`], [`StableDiffusionXLPipeline`], and [`TextToVideoSDPipeline`].
|
||||
|
||||
## StableDiffusionPipeline
|
||||
|
||||
Load the pipeline:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Then enable the FreeU mechanism with the FreeU-specific hyperparameters. These values are scaling factors for the backbone and skip features.
|
||||
|
||||
```py
|
||||
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
|
||||
```
|
||||
|
||||
The values above are from the official FreeU [code repository](https://github.com/ChenyangSi/FreeU) where you can also find [reference hyperparameters](https://github.com/ChenyangSi/FreeU#range-for-more-parameters) for different models.
|
||||
|
||||
<Tip>
|
||||
|
||||
Disable the FreeU mechanism by calling `disable_freeu()` on a pipeline.
|
||||
|
||||
</Tip>
|
||||
|
||||
And then run inference:
|
||||
|
||||
```py
|
||||
prompt = "A squirrel eating a burger"
|
||||
seed = 2023
|
||||
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
|
||||
```
|
||||
|
||||
The figure below compares non-FreeU and FreeU results respectively for the same hyperparameters used above (`prompt` and `seed`):
|
||||
|
||||

|
||||
|
||||
|
||||
Let's see how Stable Diffusion 2 results are impacted:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, safety_checker=None
|
||||
).to("cuda")
|
||||
|
||||
prompt = "A squirrel eating a burger"
|
||||
seed = 2023
|
||||
|
||||
pipeline.enable_freeu(s1=0.9, s2=0.2, b1=1.1, b2=1.2)
|
||||
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
|
||||
```
|
||||
|
||||
|
||||

|
||||
|
||||
## Stable Diffusion XL
|
||||
|
||||
Finally, let's take a look at how FreeU affects Stable Diffusion XL results:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "A squirrel eating a burger"
|
||||
seed = 2023
|
||||
|
||||
# Comes from
|
||||
# https://wandb.ai/nasirk24/UNET-FreeU-SDXL/reports/FreeU-SDXL-Optimal-Parameters--Vmlldzo1NDg4NTUw
|
||||
pipeline.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2)
|
||||
image = pipeline(prompt, generator=torch.manual_seed(seed)).images[0]
|
||||
```
|
||||
|
||||
|
||||

|
||||
|
||||
## Text-to-video generation
|
||||
|
||||
FreeU can also be used to improve video quality:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
import torch
|
||||
|
||||
model_id = "cerspense/zeroscope_v2_576w"
|
||||
pipe = DiffusionPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16).to("cuda")
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
prompt = "an astronaut riding a horse on mars"
|
||||
seed = 2023
|
||||
|
||||
# The values come from
|
||||
# https://github.com/lyn-rgb/FreeU_Diffusers#video-pipelines
|
||||
pipe.enable_freeu(b1=1.2, b2=1.4, s1=0.9, s2=0.2)
|
||||
video_frames = pipe(prompt, height=320, width=576, num_frames=30, generator=torch.manual_seed(seed)).frames
|
||||
export_to_video(video_frames, "astronaut_rides_horse.mp4")
|
||||
```
|
||||
|
||||
Thanks to [kadirnar](https://github.com/kadirnar/) for helping to integrate the feature, and to [justindujardin](https://github.com/justindujardin) for the helpful discussions.
|
||||
@@ -33,7 +33,7 @@ pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
<Tip>
|
||||
|
||||
You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](/optimization/torch2.0#scaled-dot-product-attention).
|
||||
You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, then you don't need to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention).
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -131,7 +131,7 @@ init_image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
|
||||
|
||||
# pass prompt and image to pipeline
|
||||
image = pipeline(prompt, image=init_image, strength=).images[0]
|
||||
image = pipeline(prompt, image=init_image, strength=0.5).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
@@ -590,17 +590,17 @@ image
|
||||
|
||||
## Optimize
|
||||
|
||||
Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](optimization/torch2.0#scaled-dot-product-attention) or [xFormers](optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
|
||||
Running diffusion models is computationally expensive and intensive, but with a few optimization tricks, it is entirely possible to run them on consumer and free-tier GPUs. For example, you can use a more memory-efficient form of attention such as PyTorch 2.0's [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) or [xFormers](../optimization/xformers) (you can use one or the other, but there's no need to use both). You can also offload the model to the GPU while the other pipeline components wait on the CPU.
|
||||
|
||||
```diff
|
||||
+ pipeline.enable_model_cpu_offload()
|
||||
+ pipeline.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
With [`torch.compile`](optimization/torch2.0#torch.compile), you can boost your inference speed even more by wrapping your UNet with it:
|
||||
With [`torch.compile`](../optimization/torch2.0#torch.compile), you can boost your inference speed even more by wrapping your UNet with it:
|
||||
|
||||
```py
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
To learn more, take a look at the [Reduce memory usage](optimization/memory) and [Torch 2.0](optimization/torch2.0) guides.
|
||||
To learn more, take a look at the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
|
||||
|
||||
@@ -10,87 +10,302 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Text-guided image-inpainting
|
||||
# Inpainting
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
The [`StableDiffusionInpaintPipeline`] allows you to edit specific parts of an image by providing a mask and a text prompt. It uses a version of Stable Diffusion, like [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) specifically trained for inpainting tasks.
|
||||
Inpainting replaces or edits specific areas of an image. This makes it a useful tool for image restoration like removing defects and artifacts, or even replacing an image area with something entirely new. Inpainting relies on a mask to determine which regions of an image to fill in; the area to inpaint is represented by white pixels and the area to keep is represented by black pixels. The white pixels are filled in by the prompt.
|
||||
|
||||
Get started by loading an instance of the [`StableDiffusionInpaintPipeline`]:
|
||||
With 🤗 Diffusers, here is how you can do inpainting:
|
||||
|
||||
```python
|
||||
import PIL
|
||||
import requests
|
||||
1. Load an inpainting checkpoint with the [`AutoPipelineForInpainting`] class. This'll automatically detect the appropriate pipeline class to load based on the checkpoint:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from io import BytesIO
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
|
||||
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
variant="fp16",
|
||||
)
|
||||
pipeline = pipeline.to("cuda")
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
Download an image and a mask of a dog which you'll eventually replace:
|
||||
<Tip>
|
||||
|
||||
```python
|
||||
def download_image(url):
|
||||
response = requests.get(url)
|
||||
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
|
||||
init_image = download_image(img_url).resize((512, 512))
|
||||
mask_image = download_image(mask_url).resize((512, 512))
|
||||
```
|
||||
|
||||
Now you can create a prompt to replace the mask with something else:
|
||||
|
||||
```python
|
||||
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
```
|
||||
|
||||
`image` | `mask_image` | `prompt` | output |
|
||||
:-------------------------:|:-------------------------:|:-------------------------:|-------------------------:|
|
||||
<img src="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" alt="drawing" width="250"/> | <img src="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" alt="drawing" width="250"/> | ***Face of a yellow cat, high resolution, sitting on a park bench*** | <img src="https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/yellow_cat_sitting_on_a_park_bench.png" alt="drawing" width="250"/> |
|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
A previous experimental implementation of inpainting used a different, lower-quality process. To ensure backwards compatibility, loading a pretrained pipeline that doesn't contain the new model will still apply the old inpainting method.
|
||||
You'll notice throughout the guide, we use [`~DiffusionPipeline.enable_model_cpu_offload`] and [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`], to save memory and increase inference speed. If you're using PyTorch 2.0, it's not necessary to call [`~DiffusionPipeline.enable_xformers_memory_efficient_attention`] on your pipeline because it'll already be using PyTorch 2.0's native [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention).
|
||||
|
||||
</Tip>
|
||||
|
||||
Check out the Spaces below to try out image inpainting yourself!
|
||||
2. Load the base and mask images:
|
||||
|
||||
```py
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
```
|
||||
|
||||
3. Create a prompt to inpaint the image with and pass it to the pipeline with the base and mask images:
|
||||
|
||||
```py
|
||||
prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k"
|
||||
negative_prompt = "bad anatomy, deformed, ugly, disfigured"
|
||||
image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
```
|
||||
|
||||
<div class="flex gap-4">
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">base image</figcaption>
|
||||
</div>
|
||||
<div>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-cat.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Create a mask image
|
||||
|
||||
Throughout this guide, the mask image is provided in all of the code examples for convenience. You can inpaint on your own images, but you'll need to create a mask image for it. Use the Space below to easily create a mask image.
|
||||
|
||||
Upload a base image to inpaint on and use the sketch tool to draw a mask. Once you're done, click **Run** to generate and download the mask image.
|
||||
|
||||
<iframe
|
||||
src="https://runwayml-stable-diffusion-inpainting.hf.space"
|
||||
src="https://stevhliu-inpaint-mask-maker.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="500"
|
||||
height="450"
|
||||
></iframe>
|
||||
|
||||
## Preserving the Unmasked Area of the Image
|
||||
## Popular models
|
||||
|
||||
Generally speaking, [`StableDiffusionInpaintPipeline`] (and other inpainting pipelines) will change the unmasked part of the image as well. If this behavior is undesirable, you can force the unmasked area to remain the same as follows:
|
||||
[Stable Diffusion Inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting), [Stable Diffusion XL (SDXL) Inpainting](https://huggingface.co/diffusers/stable-diffusion-xl-1.0-inpainting-0.1), and [Kandinsky 2.2](https://huggingface.co/kandinsky-community/kandinsky-2-2-decoder-inpaint) are among the most popular models for inpainting. SDXL typically produces higher resolution images than Stable Diffusion v1.5, and Kandinsky 2.2 is also capable of generating high-quality images.
|
||||
|
||||
```python
|
||||
### Stable Diffusion Inpainting
|
||||
|
||||
Stable Diffusion Inpainting is a latent diffusion model finetuned on 512x512 images on inpainting. It is a good starting point because it is relatively fast and generates good quality images. To use this model for inpainting, you'll need to pass a prompt, base and mask image to the pipeline:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# load base and mask image
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
|
||||
generator = torch.Generator("cuda").manual_seed(92)
|
||||
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
|
||||
```
|
||||
|
||||
### Stable Diffusion XL (SDXL) Inpainting
|
||||
|
||||
SDXL is a larger and more powerful version of Stable Diffusion v1.5. This model can follow a two-stage model process (though each model can also be used alone); the base model generates an image, and a refiner model takes that image and further enhances its details and quality. Take a look at the [SDXL](sdxl) guide for a more comprehensive guide on how to use SDXL and configure it's parameters.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# load base and mask image
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
|
||||
generator = torch.Generator("cuda").manual_seed(92)
|
||||
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
|
||||
```
|
||||
|
||||
### Kandinsky 2.2 Inpainting
|
||||
|
||||
The Kandinsky model family is similar to SDXL because it uses two models as well; the image prior model creates image embeddings, and the diffusion model generates images from them. You can load the image prior and diffusion model separately, but the easiest way to use Kandinsky 2.2 is to load it into the [`AutoPipelineForInpainting`] class which uses the [`KandinskyV22InpaintCombinedPipeline`] under the hood.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# load base and mask image
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
|
||||
generator = torch.Generator("cuda").manual_seed(92)
|
||||
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, generator=generator).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">base image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-sdv1.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Stable Diffusion Inpainting</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-sdxl.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Stable Diffusion XL Inpainting</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-kandinsky.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">Kandinsky 2.2 Inpainting</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Configure pipeline parameters
|
||||
|
||||
Image features - like quality and "creativity" - are dependent on pipeline parameters. Knowing what these parameters do is important for getting the results you want. Let's take a look at the most important parameters and see how changing them affects the output.
|
||||
|
||||
### Strength
|
||||
|
||||
`strength` is a measure of how much noise is added to the base image, which influences how similar the output is to the base image.
|
||||
|
||||
* 📈 a high `strength` value means more noise is added to an image and the denoising process takes longer, but you'll get higher quality images that are more different from the base image
|
||||
* 📉 a low `strength` value means less noise is added to an image and the denoising process is faster, but the image quality may not be as great and the generated image resembles the base image more
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# load base and mask image
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
|
||||
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.6).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-strength-0.6.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 0.6</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-strength-0.8.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 0.8</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-strength-1.0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">strength = 1.0</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Guidance scale
|
||||
|
||||
`guidance_scale` affects how aligned the text prompt and generated image are.
|
||||
|
||||
* 📈 a high `guidance_scale` value means the prompt and generated image are closely aligned, so the output is a stricter interpretation of the prompt
|
||||
* 📉 a low `guidance_scale` value means the prompt and generated image are more loosely aligned, so the output may be more varied from the prompt
|
||||
|
||||
You can use `strength` and `guidance_scale` together for more control over how expressive the model is. For example, a combination high `strength` and `guidance_scale` values gives the model the most creative freedom.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# load base and mask image
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
|
||||
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, guidance_scale=2.5).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-guidance-2.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 2.5</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-guidance-7.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 7.5</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-guidance-12.5.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">guidance_scale = 12.5</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
### Negative prompt
|
||||
|
||||
A negative prompt assumes the opposite role of a prompt; it guides the model away from generating certain things in an image. This is useful for quickly improving image quality and preventing the model from generating things you don't want.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# load base and mask image
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
|
||||
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
|
||||
negative_prompt = "bad architecture, unstable, poor details, blurry"
|
||||
image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<figure>
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-negative.png" />
|
||||
<figcaption class="text-center">negative_prompt = "bad architecture, unstable, poor details, blurry"</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
## Preserve unmasked areas
|
||||
|
||||
The [`AutoPipelineForInpainting`] (and other inpainting pipelines) generally changes the unmasked parts of an image to create a more natural transition between the masked and unmasked region. If this behavior is undesirable, you can force the unmasked area to remain the same. However, forcing the unmasked portion of the image to remain the same may result in some unusual transitions between the unmasked and masked areas.
|
||||
|
||||
```py
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
device = "cuda"
|
||||
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting",
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
@@ -121,4 +336,257 @@ unmasked_unchanged_image = PIL.Image.fromarray(unmasked_unchanged_image_arr.roun
|
||||
unmasked_unchanged_image.save("force_unmasked_unchanged.png")
|
||||
```
|
||||
|
||||
Forcing the unmasked portion of the image to remain the same might result in some weird transitions between the unmasked and masked areas, since the model will typically change the masked and unmasked areas to make the transition more natural.
|
||||
## Chained inpainting pipelines
|
||||
|
||||
[`AutoPipelineForInpainting`] can be chained with other 🤗 Diffusers pipelines to edit their outputs. This is often useful for improving the output quality from your other diffusion pipelines, and if you're using multiple pipelines, it can be more memory-efficient to chain them together to keep the outputs in latent space and reuse the same pipeline components.
|
||||
|
||||
### Text-to-image-to-inpaint
|
||||
|
||||
Chaining a text-to-image and inpainting pipeline allows you to inpaint the generated image, and you don't have to provide a base image to begin with. This makes it convenient to edit your favorite text-to-image outputs without having to generate an entirely new image.
|
||||
|
||||
Start with the text-to-image pipeline to create a castle:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image, AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image = pipeline("concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k").images[0]
|
||||
```
|
||||
|
||||
Load the mask image of the output from above:
|
||||
|
||||
```py
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_text-chain-mask.png").convert("RGB")
|
||||
```
|
||||
|
||||
And let's inpaint the masked area with a waterfall:
|
||||
|
||||
```py
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
prompt = "digital painting of a fantasy waterfall, cloudy"
|
||||
image = pipeline(prompt=prompt, image=image, mask_image=mask_image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-text-chain.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">text-to-image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-text-chain-out.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">inpaint</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
### Inpaint-to-image-to-image
|
||||
|
||||
You can also chain an inpainting pipeline before another pipeline like image-to-image or an upscaler to improve the quality.
|
||||
|
||||
Begin by inpainting an image:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting, AutoPipelineForImage2Image
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# load base and mask image
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
|
||||
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
|
||||
|
||||
# resize image to 1024x1024 for SDXL
|
||||
image = image.resize((1024, 1024))
|
||||
```
|
||||
|
||||
Now let's pass the image to another inpainting pipeline with SDXL's refiner model to enhance the image details and quality:
|
||||
|
||||
```py
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image = pipeline(prompt=prompt, image=image, mask_image=mask_image, output_type="latent").images[0]
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
It is important to specify `output_type="latent"` in the pipeline to keep all the outputs in latent space to avoid an unnecessary decode-encode step. This only works if the chained pipelines are using the same VAE. For example, in the [Text-to-image-to-inpaint](#text-to-image-to-inpaint) section, Kandinsky 2.2 uses a different VAE class than the Stable Diffusion model so it won't work. But if you use Stable Diffusion v1.5 for both pipelines, then you can keep everything in latent space because they both use [`AutoencoderKL`].
|
||||
|
||||
</Tip>
|
||||
|
||||
Finally, you can pass this image to an image-to-image pipeline to put the finishing touches on it. It is more efficient to use the [`~AutoPipelineForImage2Image.from_pipe`] method to reuse the existing pipeline components, and avoid unnecessarily loading all the pipeline components into memory again.
|
||||
|
||||
```py
|
||||
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline)
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image = pipeline(prompt=prompt, image=image).images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-to-image-chain.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">inpaint</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-to-image-final.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">image-to-image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Image-to-image and inpainting are actually very similar tasks. Image-to-image generates a new image that resembles the existing provided image. Inpainting does the same thing, but it only transforms the image area defined by the mask and the rest of the image is unchanged. You can think of inpainting as a more precise tool for making specific changes and image-to-image has a broader scope for making more sweeping changes.
|
||||
|
||||
## Control image generation
|
||||
|
||||
Getting an image to look exactly the way you want is challenging because the denoising process is random. While you can control certain aspects of generation by configuring parameters like `negative_prompt`, there are better and more efficient methods for controlling image generation.
|
||||
|
||||
### Prompt weighting
|
||||
|
||||
Prompt weighting provides a quantifiable way to scale the representation of concepts in a prompt. You can use it to increase or decrease the magnitude of the text embedding vector for each concept in the prompt, which subsequently determines how much of each concept is generated. The [Compel](https://github.com/damian0815/compel) library offers an intuitive syntax for scaling the prompt weights and generating the embeddings. Learn how to create the embeddings in the [Prompt weighting](../using-diffusers/weighted_prompts) guide.
|
||||
|
||||
Once you've generated the embeddings, pass them to the `prompt_embeds` (and `negative_prompt_embeds` if you're using a negative prompt) parameter in the [`AutoPipelineForInpainting`]. The embeddings replace the `prompt` parameter:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
image = pipeline(prompt_emebds=prompt_embeds, # generated from Compel
|
||||
negative_prompt_embeds, # generated from Compel
|
||||
image=init_image,
|
||||
mask_image=mask_image
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### ControlNet
|
||||
|
||||
ControlNet models are used with other diffusion models like Stable Diffusion, and they provide an even more flexible and accurate way to control how an image is generated. A ControlNet accepts an additional conditioning image input that guides the diffusion model to preserve the features in it.
|
||||
|
||||
For example, let's condition an image with a ControlNet pretrained on inpaint images:
|
||||
|
||||
```py
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
# load ControlNet
|
||||
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16, variant="fp16")
|
||||
|
||||
# pass ControlNet to the pipeline
|
||||
pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# load base and mask image
|
||||
init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png").convert("RGB")
|
||||
mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png").convert("RGB")
|
||||
|
||||
# prepare control image
|
||||
def make_inpaint_condition(init_image, mask_image):
|
||||
init_image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
|
||||
mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
|
||||
|
||||
assert init_image.shape[0:1] == mask_image.shape[0:1], "image and image_mask must have the same image size"
|
||||
init_image[mask_image > 0.5] = -1.0 # set as masked pixel
|
||||
init_image = np.expand_dims(init_image, 0).transpose(0, 3, 1, 2)
|
||||
init_image = torch.from_numpy(init_image)
|
||||
return init_image
|
||||
|
||||
control_image = make_inpaint_condition(init_image, mask_image)
|
||||
```
|
||||
|
||||
Now generate an image from the base, mask and control images. You'll notice features of the base image are strongly preserved in the generated image.
|
||||
|
||||
```py
|
||||
prompt = "concept art digital painting of an elven castle, inspired by lord of the rings, highly detailed, 8k"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
You can take this a step further and chain it with an image-to-image pipeline to apply a new [style](https://huggingface.co/nitrosocke/elden-ring-diffusion):
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"nitrosocke/elden-ring-diffusion", torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
pipeline.enable_model_cpu_offload()
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
|
||||
prompt = "elden ring style castle" # include the token "elden ring style" in the prompt
|
||||
negative_prompt = "bad architecture, deformed, disfigured, poor details"
|
||||
|
||||
image = pipeline(prompt, negative_prompt=negative_prompt, image=image).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">initial image</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-controlnet.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">ControlNet inpaint</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint-img2img.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">image-to-image</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## Optimize
|
||||
|
||||
It can be difficult and slow to run diffusion models if you're resource constrained, but it dosen't have to be with a few optimization tricks. One of the biggest (and easiest) optimizations you can enable is switching to memory-efficient attention. If you're using PyTorch 2.0, [scaled-dot product attention](../optimization/torch2.0#scaled-dot-product-attention) is automatically enabled and you don't need to do anything else. For non-PyTorch 2.0 users, you can install and use [xFormers](../optimization/xformers)'s implementation of memory-efficient attention. Both options reduce memory usage and accelerate inference.
|
||||
|
||||
You can also offload the model to the GPU to save even more memory:
|
||||
|
||||
```diff
|
||||
+ pipeline.enable_xformers_memory_efficient_attention()
|
||||
+ pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
To speed-up your inference code even more, use [`torch_compile`](../optimization/torch2.0#torch.compile). You should wrap `torch.compile` around the most intensive component in the pipeline which is typically the UNet:
|
||||
|
||||
```py
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
|
||||
Learn more in the [Reduce memory usage](../optimization/memory) and [Torch 2.0](../optimization/torch2.0) guides.
|
||||
@@ -39,7 +39,7 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
).to("cuda")
|
||||
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_single_file(
|
||||
refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
@@ -112,7 +112,7 @@ As you can see, this is already more complex than the DDPM pipeline which only c
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models.
|
||||
💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models work.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -214,7 +214,7 @@ Next, generate some initial random noise as a starting point for the diffusion p
|
||||
|
||||
```py
|
||||
>>> latents = torch.randn(
|
||||
... (batch_size, unet.in_channels, height // 8, width // 8),
|
||||
... (batch_size, unet.config.in_channels, height // 8, width // 8),
|
||||
... generator=generator,
|
||||
... )
|
||||
>>> latents = latents.to(torch_device)
|
||||
|
||||
@@ -29,26 +29,32 @@ Unconditional 이미지 생성은 비교적 간단한 작업입니다. 모델이
|
||||
|
||||
이 가이드에서는 unconditional 이미지 생성에 ['DiffusionPipeline']과 [DDPM](https://arxiv.org/abs/2006.11239)을 사용합니다:
|
||||
|
||||
```python
|
||||
```python
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
|
||||
>>> generator = DiffusionPipeline.from_pretrained("anton-l/ddpm-butterflies-128")
|
||||
```
|
||||
```
|
||||
|
||||
[diffusion 파이프라인]은 모든 모델링, 토큰화, 스케줄링 구성 요소를 다운로드하고 캐시합니다. 이 모델은 약 14억 개의 파라미터로 구성되어 있기 때문에 GPU에서 실행할 것을 강력히 권장합니다. PyTorch에서와 마찬가지로 제너레이터 객체를 GPU로 옮길 수 있습니다:
|
||||
```python
|
||||
|
||||
```python
|
||||
>>> generator.to("cuda")
|
||||
```
|
||||
```
|
||||
|
||||
이제 제너레이터를 사용하여 이미지를 생성할 수 있습니다:
|
||||
```python
|
||||
|
||||
```python
|
||||
>>> image = generator().images[0]
|
||||
```
|
||||
```
|
||||
|
||||
출력은 기본적으로 [PIL.Image](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class) 객체로 감싸집니다.
|
||||
|
||||
다음을 호출하여 이미지를 저장할 수 있습니다:
|
||||
```python
|
||||
|
||||
```python
|
||||
>>> image.save("generated_image.png")
|
||||
```
|
||||
```
|
||||
|
||||
아래 스페이스(데모 링크)를 이용해 보고, 추론 단계의 매개변수를 자유롭게 조절하여 이미지 품질에 어떤 영향을 미치는지 확인해 보세요!
|
||||
|
||||
<iframe src="https://stevhliu-ddpm-butterflies-128.hf.space" frameborder="0" width="850" height="500"></iframe>
|
||||
<iframe src="https://stevhliu-ddpm-butterflies-128.hf.space" frameborder="0" width="850" height="500"></iframe>
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
title: 🧨 Diffusers
|
||||
- local: quicktour
|
||||
title: 快速入门
|
||||
- local: stable_diffusion
|
||||
title: 有效和高效的扩散
|
||||
- local: installation
|
||||
title: 安装
|
||||
title: 开始
|
||||
|
||||
@@ -0,0 +1,264 @@
|
||||
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# 有效且高效的扩散
|
||||
|
||||
[[open-in-colab]]
|
||||
|
||||
让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。 通常情况下,你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程,特别是如果你要一遍又一遍地进行推理运算。
|
||||
|
||||
这就是为什么从pipeline中获得最高的 *computational* (speed) 和 *memory* (GPU RAM) 非常重要 ,以减少推理周期之间的时间,从而使迭代速度更快。
|
||||
|
||||
|
||||
本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。
|
||||
|
||||
|
||||
首先,加载 [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) 模型:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "runwayml/stable-diffusion-v1-5"
|
||||
pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
|
||||
```
|
||||
|
||||
本教程将使用的提示词是 [`portrait photo of a old warrior chief`] ,但是你可以随心所欲的想象和构造自己的提示词:
|
||||
|
||||
```python
|
||||
prompt = "portrait photo of a old warrior chief"
|
||||
```
|
||||
|
||||
## 速度
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !
|
||||
|
||||
</Tip>
|
||||
|
||||
加速推理的最简单方法之一是将 pipeline 放在 GPU 上 ,就像使用任何 PyTorch 模块一样:
|
||||
|
||||
```python
|
||||
pipeline = pipeline.to("cuda")
|
||||
```
|
||||
|
||||
为了确保您可以使用相同的图像并对其进行改进,使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 方法,然后设置一个随机数种子 以确保其 [复现性](./using-diffusers/reproducibility):
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
generator = torch.Generator("cuda").manual_seed(0)
|
||||
```
|
||||
|
||||
现在,你可以生成一个图像:
|
||||
|
||||
```python
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_1.png">
|
||||
</div>
|
||||
|
||||
在 T4 GPU 上,这个过程大概要30秒(如果你的 GPU 比 T4 好,可能会更快)。在默认情况下,[`DiffusionPipeline`] 使用完整的 `float32` 精度进行 50 步推理。你可以通过降低精度(如 `float16` )或者减少推理步数来加速整个过程
|
||||
|
||||
|
||||
让我们把模型的精度降低至 `float16` ,然后生成一张图像:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
|
||||
pipeline = pipeline.to("cuda")
|
||||
generator = torch.Generator("cuda").manual_seed(0)
|
||||
image = pipeline(prompt, generator=generator).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_2.png">
|
||||
</div>
|
||||
|
||||
这一次,生成图像只花了约 11 秒,比之前快了近 3 倍!
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 我们强烈建议把 pipeline 精度降低至 `float16` , 到目前为止, 我们很少看到输出质量有任何下降。
|
||||
|
||||
</Tip>
|
||||
|
||||
另一个选择是减少推理步数。 你可以选择一个更高效的调度器 (*scheduler*) 可以减少推理步数同时保证输出质量。您可以在 [DiffusionPipeline] 中通过调用compatibles方法找到与当前模型兼容的调度器 (*scheduler*)。
|
||||
|
||||
```python
|
||||
pipeline.scheduler.compatibles
|
||||
[
|
||||
diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
|
||||
diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
|
||||
diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
|
||||
diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
|
||||
diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
|
||||
diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
|
||||
diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
|
||||
diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
|
||||
diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
|
||||
diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
|
||||
diffusers.schedulers.scheduling_pndm.PNDMScheduler,
|
||||
diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
|
||||
diffusers.schedulers.scheduling_ddim.DDIMScheduler,
|
||||
]
|
||||
```
|
||||
|
||||
Stable Diffusion 模型默认使用的是 [`PNDMScheduler`] ,通常要大概50步推理, 但是像 [`DPMSolverMultistepScheduler`] 这样更高效的调度器只要大概 20 或 25 步推理. 使用 [`ConfigMixin.from_config`] 方法加载新的调度器:
|
||||
|
||||
```python
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
```
|
||||
|
||||
现在将 `num_inference_steps` 设置为 20:
|
||||
|
||||
```python
|
||||
generator = torch.Generator("cuda").manual_seed(0)
|
||||
image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
|
||||
image
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_3.png">
|
||||
</div>
|
||||
|
||||
太棒了!你成功把推理时间缩短到 4 秒!⚡️
|
||||
|
||||
## 内存
|
||||
|
||||
改善 pipeline 性能的另一个关键是减少内存的使用量,这间接意味着速度更快,因为你经常试图最大化每秒生成的图像数量。要想知道你一次可以生成多少张图片,最简单的方法是尝试不同的batch size,直到出现`OutOfMemoryError` (OOM)。
|
||||
|
||||
创建一个函数,为每一批要生成的图像分配提示词和 `Generators` 。请务必为每个`Generator` 分配一个种子,以便于复现良好的结果。
|
||||
|
||||
|
||||
```python
|
||||
def get_inputs(batch_size=1):
|
||||
generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
|
||||
prompts = batch_size * [prompt]
|
||||
num_inference_steps = 20
|
||||
|
||||
return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
|
||||
```
|
||||
|
||||
设置 `batch_size=4` ,然后看一看我们消耗了多少内存:
|
||||
|
||||
```python
|
||||
from diffusers.utils import make_image_grid
|
||||
|
||||
images = pipeline(**get_inputs(batch_size=4)).images
|
||||
make_image_grid(images, 2, 2)
|
||||
```
|
||||
|
||||
除非你有一个更大内存的GPU, 否则上述代码会返回 `OOM` 错误! 大部分内存被 cross-attention 层使用。按顺序运行可以节省大量内存,而不是在批处理中进行。你可以为 pipeline 配置 [`~DiffusionPipeline.enable_attention_slicing`] 函数:
|
||||
|
||||
```python
|
||||
pipeline.enable_attention_slicing()
|
||||
```
|
||||
|
||||
现在尝试把 `batch_size` 增加到 8!
|
||||
|
||||
```python
|
||||
images = pipeline(**get_inputs(batch_size=8)).images
|
||||
make_image_grid(images, rows=2, cols=4)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_5.png">
|
||||
</div>
|
||||
|
||||
以前你不能一批生成 4 张图片,而现在你可以在一张图片里面生成八张图片而只需要大概3.5秒!这可能是 T4 GPU 在不牺牲质量的情况运行速度最快的一种方法。
|
||||
|
||||
## 质量
|
||||
|
||||
在最后两节中, 你要学习如何通过 `fp16` 来优化 pipeline 的速度, 通过使用性能更高的调度器来减少推理步数, 使用注意力切片(*enabling attention slicing*)方法来节省内存。现在,你将关注的是如何提高图像的质量。
|
||||
|
||||
### 更好的 checkpoints
|
||||
|
||||
有个显而易见的方法是使用更好的 checkpoints。 Stable Diffusion 模型是一个很好的起点, 自正式发布以来,还发布了几个改进版本。然而, 使用更新的版本并不意味着你会得到更好的结果。你仍然需要尝试不同的 checkpoints ,并做一些研究 (例如使用 [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) 来获得更好的结果。
|
||||
|
||||
随着该领域的发展, 有越来越多经过微调的高质量的 checkpoints 用来生成不一样的风格. 在 [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 和 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) 寻找你感兴趣的一种!
|
||||
|
||||
### 更好的 pipeline 组件
|
||||
|
||||
也可以尝试用新版本替换当前 pipeline 组件。让我们加载最新的 [autodecoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) 从 Stability AI 加载到 pipeline, 并生成一些图像:
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.vae = vae
|
||||
images = pipeline(**get_inputs(batch_size=8)).images
|
||||
make_image_grid(images, rows=2, cols=4)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_6.png">
|
||||
</div>
|
||||
|
||||
### 更好的提示词工程
|
||||
|
||||
用于生成图像的文本非常重要, 因此被称为 *提示词工程*。 在设计提示词工程应注意如下事项:
|
||||
|
||||
- 我想生成的图像或类似图像如何存储在互联网上?
|
||||
- 我可以提供哪些额外的细节来引导模型朝着我想要的风格生成?
|
||||
|
||||
考虑到这一点,让我们改进提示词,以包含颜色和更高质量的细节:
|
||||
|
||||
```python
|
||||
prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
|
||||
prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
|
||||
```
|
||||
|
||||
使用新的提示词生成一批图像:
|
||||
|
||||
```python
|
||||
images = pipeline(**get_inputs(batch_size=8)).images
|
||||
make_image_grid(images, rows=2, cols=4)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_7.png">
|
||||
</div>
|
||||
|
||||
非常的令人印象深刻! Let's tweak the second image - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
|
||||
|
||||
```python
|
||||
prompts = [
|
||||
"portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
"portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
"portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
"portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
|
||||
]
|
||||
|
||||
generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
|
||||
images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
|
||||
make_image_grid(images, 2, 2)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/docs-images/resolve/main/stable_diffusion_101/sd_101_8.png">
|
||||
</div>
|
||||
|
||||
## 最后
|
||||
|
||||
在本教程中, 您学习了如何优化[`DiffusionPipeline`]以提高计算和内存效率,以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:
|
||||
|
||||
- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) 可以让推理速度提高 5 - 300% . 在 A100 GPU 上, 推理速度可以提高 50% !
|
||||
- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制(*memory-efficient attention mechanism*)与PyTorch 1.13.1配合使用,速度更快,内存消耗更少。
|
||||
- 其他的优化技术, 如:模型卸载(*model offloading*), 包含在 [这份指南](./optimization/fp16).
|
||||
@@ -562,7 +562,8 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
image = self.decode_latents(latents)
|
||||
|
||||
@@ -434,7 +434,8 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -372,7 +372,8 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -1088,7 +1088,8 @@ class StableDiffusionLongPromptWeightingPipeline(
|
||||
progress_bar.update()
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
|
||||
@@ -846,7 +846,8 @@ class OnnxStableDiffusionLongPromptWeightingPipeline(OnnxStableDiffusionPipeline
|
||||
# call the callback, if provided
|
||||
if i % callback_steps == 0:
|
||||
if callback is not None:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
if is_cancelled_callback is not None and is_cancelled_callback():
|
||||
return None
|
||||
|
||||
|
||||
@@ -1182,7 +1182,8 @@ class SDXLLongPromptWeightingPipeline(DiffusionPipeline, FromSingleFileMixin, Lo
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
|
||||
@@ -202,7 +202,8 @@ class MaskedStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
scaled = latents / self.vae.config.scaling_factor
|
||||
|
||||
@@ -407,7 +407,8 @@ class MultilingualStableDiffusion(DiffusionPipeline):
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -254,7 +254,8 @@ class Prompt2PromptPipeline(StableDiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
if not output_type == "latent":
|
||||
|
||||
@@ -865,7 +865,8 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
has_nsfw_concept = None
|
||||
|
||||
@@ -815,7 +815,8 @@ class OnnxStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
_latents = latents.cpu().detach().numpy() / 0.18215
|
||||
@@ -886,7 +887,7 @@ if __name__ == "__main__":
|
||||
onnx_pipeline = onnx_pipeline.to("cuda")
|
||||
|
||||
prompt = "a cute cat fly to the moon"
|
||||
negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, tranny, trans, trannsexual, hermaphrodite, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect"
|
||||
negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect"
|
||||
|
||||
for i in range(10):
|
||||
start_time = time.time()
|
||||
|
||||
@@ -919,7 +919,8 @@ class TensorRTStableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
_latents = latents.cpu().detach().numpy() / 0.18215
|
||||
@@ -997,7 +998,7 @@ if __name__ == "__main__":
|
||||
onnx_pipeline = onnx_pipeline.to("cuda")
|
||||
|
||||
prompt = "a cute cat fly to the moon"
|
||||
negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, tranny, trans, trannsexual, hermaphrodite, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect"
|
||||
negative_prompt = "paintings, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, age spot, glans, nsfw, nipples, necklace, worst quality, low quality, watermark, username, signature, multiple breasts, lowres, bad anatomy, bad hands, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, bad feet, single color, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, disfigured, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, bad body perspect"
|
||||
|
||||
for i in range(10):
|
||||
start_time = time.time()
|
||||
|
||||
@@ -337,7 +337,8 @@ class SeedResizeStableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -242,7 +242,8 @@ class SpeechToImagePipeline(DiffusionPipeline):
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -951,7 +951,8 @@ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
|
||||
@@ -1100,7 +1100,8 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
|
||||
@@ -1081,7 +1081,8 @@ class StableDiffusionControlNetInpaintImg2ImgPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
|
||||
@@ -802,7 +802,8 @@ class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeli
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
|
||||
@@ -817,7 +817,8 @@ class StableDiffusionIPEXPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
@@ -770,7 +770,8 @@ class StableDiffusionReferencePipeline(StableDiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
@@ -932,7 +932,8 @@ class StableDiffusionRepaintPipeline(DiffusionPipeline, TextualInversionLoaderMi
|
||||
# call the callback, if provided
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
t_last = t
|
||||
|
||||
|
||||
@@ -771,7 +771,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
|
||||
@@ -389,7 +389,8 @@ class WildcardStableDiffusionPipeline(DiffusionPipeline):
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
latents = 1 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
@@ -907,17 +907,10 @@ def main():
|
||||
|
||||
if args.snr_gamma is not None:
|
||||
snr = jnp.array(compute_snr(timesteps))
|
||||
base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
snr_loss_weights = base_weights + 1
|
||||
else:
|
||||
# Epsilon and sample prediction use the base weights.
|
||||
snr_loss_weights = base_weights
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
snr_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
|
||||
loss = loss * snr_loss_weights
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -207,7 +207,7 @@ class CustomDiffusionDataset(Dataset):
|
||||
with open(concept["class_prompt"], "r") as f:
|
||||
class_prompt = f.read().splitlines()
|
||||
|
||||
class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)]
|
||||
class_img_path = list(zip(class_images_path, class_prompt))
|
||||
self.class_images_path.extend(class_img_path[:num_class_images])
|
||||
|
||||
random.shuffle(self.instance_images_path)
|
||||
@@ -1075,30 +1075,30 @@ def main(args):
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
if args.modifier_token is not None:
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
@@ -1214,50 +1214,52 @@ def main(args):
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if accelerator.is_main_process:
|
||||
images = []
|
||||
if accelerator.is_main_process:
|
||||
images = []
|
||||
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
]
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[
|
||||
0
|
||||
]
|
||||
for _ in range(args.num_validation_images)
|
||||
]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Save the custom diffusion layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -52,6 +52,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -224,30 +225,6 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
|
||||
def compute_snr(timesteps, noise_scheduler):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
@@ -1201,30 +1178,30 @@ def main(args):
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
||||
|
||||
@@ -1302,7 +1279,7 @@ def main(args):
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps, noise_scheduler)
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
base_weight = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
@@ -24,7 +24,6 @@ import os
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -59,12 +58,11 @@ from diffusers.loaders import (
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
)
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import unet_lora_state_dict
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -658,22 +656,6 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
|
||||
return prompt_embeds
|
||||
|
||||
|
||||
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
|
||||
r"""
|
||||
Returns:
|
||||
a state dict containing just the attention processor parameters.
|
||||
"""
|
||||
attn_processors = unet.attn_processors
|
||||
|
||||
attn_processors_state_dict = {}
|
||||
|
||||
for attn_processor_key, attn_processor in attn_processors.items():
|
||||
for parameter_key, parameter in attn_processor.state_dict().items():
|
||||
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
|
||||
|
||||
return attn_processors_state_dict
|
||||
|
||||
|
||||
def main(args):
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
@@ -854,37 +836,64 @@ def main(args):
|
||||
# For Stable Diffusion, it should be equal to:
|
||||
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
|
||||
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
|
||||
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
|
||||
# - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
|
||||
# => 32 layers
|
||||
|
||||
# Set correct lora layers
|
||||
unet_lora_attn_procs = {}
|
||||
unet_lora_parameters = []
|
||||
for name, attn_processor in unet.attn_processors.items():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
|
||||
if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
|
||||
lora_attn_processor_class = LoRAAttnAddedKVProcessor
|
||||
else:
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
attn_module.add_k_proj.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.add_k_proj.in_features,
|
||||
out_features=attn_module.add_k_proj.out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
|
||||
module = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
|
||||
)
|
||||
unet_lora_attn_procs[name] = module
|
||||
unet_lora_parameters.extend(module.parameters())
|
||||
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
attn_module.add_v_proj.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.add_v_proj.in_features,
|
||||
out_features=attn_module.add_v_proj.out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
unet_lora_parameters.extend(attn_module.add_k_proj.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.add_v_proj.lora_layer.parameters())
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
@@ -902,7 +911,7 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
|
||||
unet_lora_layers_to_save = unet_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
|
||||
text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
else:
|
||||
@@ -1108,30 +1117,30 @@ def main(args):
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
|
||||
|
||||
@@ -1338,7 +1347,7 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = unet_attn_processors_state_dict(unet)
|
||||
unet_lora_layers = unet_lora_state_dict(unet)
|
||||
|
||||
if text_encoder is not None and args.train_text_encoder:
|
||||
text_encoder = accelerator.unwrap_model(text_encoder)
|
||||
|
||||
@@ -23,7 +23,6 @@ import os
|
||||
import shutil
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -51,8 +50,9 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import unet_lora_state_dict
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -575,22 +575,6 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
|
||||
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
|
||||
"""
|
||||
Returns:
|
||||
a state dict containing just the attention processor parameters.
|
||||
"""
|
||||
attn_processors = unet.attn_processors
|
||||
|
||||
attn_processors_state_dict = {}
|
||||
|
||||
for attn_processor_key, attn_processor in attn_processors.items():
|
||||
for parameter_key, parameter in attn_processor.state_dict().items():
|
||||
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
|
||||
|
||||
return attn_processors_state_dict
|
||||
|
||||
|
||||
def main(args):
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
@@ -761,29 +745,52 @@ def main(args):
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
# Set correct lora layers
|
||||
unet_lora_attn_procs = {}
|
||||
unet_lora_parameters = []
|
||||
for name, attn_processor in unet.attn_processors.items():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features,
|
||||
out_features=attn_module.to_q.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
module = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features,
|
||||
out_features=attn_module.to_k.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features,
|
||||
out_features=attn_module.to_v.out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
unet_lora_attn_procs[name] = module
|
||||
unet_lora_parameters.extend(module.parameters())
|
||||
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
@@ -807,7 +814,7 @@ def main(args):
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(accelerator.unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
|
||||
unet_lora_layers_to_save = unet_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
|
||||
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
|
||||
@@ -1048,18 +1055,25 @@ def main(args):
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
@@ -1067,12 +1081,6 @@ def main(args):
|
||||
text_encoder_one.train()
|
||||
text_encoder_two.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||
|
||||
@@ -1273,7 +1281,7 @@ def main(args):
|
||||
if accelerator.is_main_process:
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = unet_attn_processors_state_dict(unet)
|
||||
unet_lora_layers = unet_lora_state_dict(unet)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
|
||||
@@ -726,6 +726,9 @@ def main():
|
||||
text_encoder_1.requires_grad_(False)
|
||||
text_encoder_2.requires_grad_(False)
|
||||
|
||||
# Set UNet to trainable.
|
||||
unet.train()
|
||||
|
||||
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(text_encoders, tokenizers, prompt):
|
||||
prompt_embeds_list = []
|
||||
@@ -933,29 +936,28 @@ def main():
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# We want to learn the denoising process w.r.t the edited images which
|
||||
# are conditioned on the original image (which was edited) and the edit instruction.
|
||||
|
||||
@@ -42,7 +42,7 @@ from transformers.utils import ContextManagers
|
||||
import diffusers
|
||||
from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionModel, VQModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -512,6 +512,9 @@ def main():
|
||||
vae.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
|
||||
# Set unet to trainable.
|
||||
unet.train()
|
||||
|
||||
# Create EMA for the unet.
|
||||
if args.use_ema:
|
||||
ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
|
||||
@@ -530,30 +533,6 @@ def main():
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -751,27 +730,28 @@ def main():
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
images = batch["pixel_values"].to(weight_dtype)
|
||||
@@ -800,26 +780,14 @@ def main():
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -41,6 +41,7 @@ from diffusers import AutoPipelineForText2Image, DDPMScheduler, UNet2DConditionM
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnAddedKVProcessor
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
@@ -419,30 +420,6 @@ def main():
|
||||
|
||||
unet.set_attn_processor(lora_attn_procs)
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
|
||||
if args.allow_tf32:
|
||||
@@ -602,29 +579,29 @@ def main():
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
images = batch["pixel_values"].to(weight_dtype)
|
||||
@@ -653,26 +630,14 @@ def main():
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -41,6 +41,7 @@ from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
@@ -413,31 +414,6 @@ def main():
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=2048, rank=args.rank)
|
||||
|
||||
prior.set_attn_processor(lora_attn_procs)
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
lora_layers = AttnProcsLayers(prior.attn_processors)
|
||||
|
||||
if args.allow_tf32:
|
||||
@@ -619,30 +595,33 @@ def main():
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
|
||||
clip_std = clip_std.to(weight_dtype).to(accelerator.device)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
prior.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(prior):
|
||||
# Convert images to latent space
|
||||
text_input_ids, text_mask, clip_images = (
|
||||
@@ -684,26 +663,14 @@ def main():
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -42,7 +42,7 @@ from transformers.utils import ContextManagers
|
||||
import diffusers
|
||||
from diffusers import AutoPipelineForText2Image, DDPMScheduler, PriorTransformer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
|
||||
|
||||
@@ -517,36 +517,15 @@ def main():
|
||||
text_encoder.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
|
||||
# Set prior to trainable.
|
||||
prior.train()
|
||||
|
||||
# Create EMA for the prior.
|
||||
if args.use_ema:
|
||||
ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
||||
ema_prior = EMAModel(ema_prior.parameters(), model_cls=PriorTransformer, model_config=ema_prior.config)
|
||||
ema_prior.to(accelerator.device)
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -765,32 +744,31 @@ def main():
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
|
||||
clip_std = clip_std.to(weight_dtype).to(accelerator.device)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
prior.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(prior):
|
||||
# Convert images to latent space
|
||||
text_input_ids, text_mask, clip_images = (
|
||||
@@ -832,26 +810,14 @@ def main():
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -44,7 +44,7 @@ from transformers.utils import ContextManagers
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -524,30 +524,6 @@ def main():
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -871,25 +847,14 @@ def main():
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample prediction use the base weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -432,7 +432,8 @@ class RDMPipeline(DiffusionPipeline):
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
else:
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
# Stable Diffusion XL for JAX + TPUv5e
|
||||
|
||||
[TPU v5e](https://cloud.google.com/blog/products/compute/how-cloud-tpu-v5e-accelerates-large-scale-ai-inference) is a new generation of TPUs from Google Cloud. It is the most cost-effective, versatile, and scalable Cloud TPU to date. This makes them ideal for serving and scaling large diffusion models.
|
||||
|
||||
[JAX](https://github.com/google/jax) is a high-performance numerical computation library that is well-suited to develop and deploy diffusion models:
|
||||
|
||||
- **High performance**. All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) - the Accelerated Linear Algebra compiler
|
||||
|
||||
- **Compilation**. JAX uses just-in-time (jit) compilation of JAX Python functions so it can be executed efficiently in XLA. In order to get the best performance, we must use static shapes for jitted functions, this is because JAX transforms work by tracing a function and to determine its effect on inputs of a specific shape and type. When a new shape is introduced to an already compiled function, it retriggers compilation on the new shape, which can greatly reduce performance. **Note**: JIT compilation is particularly well-suited for text-to-image generation because all inputs and outputs (image input / output sizes) are static.
|
||||
|
||||
- **Parallelization**. Workloads can be scaled across multiple devices using JAX's [pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html), which expresses single-program multiple-data (SPMD) programs. Applying pmap to a function will compile a function with XLA, then execute in parallel on XLA devices. For text-to-image generation workloads this means that increasing the number of images rendered simultaneously is straightforward to implement and doesn't compromise performance.
|
||||
|
||||
👉 Try it out for yourself:
|
||||
|
||||
[](https://huggingface.co/spaces/google/sdxl)
|
||||
|
||||
## Stable Diffusion XL pipeline in JAX
|
||||
|
||||
Upon having access to a TPU VM (TPUs higher than version 3), you should first install
|
||||
a TPU-compatible version of JAX:
|
||||
```
|
||||
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
```
|
||||
|
||||
Next, we can install [flax](https://github.com/google/flax) and the diffusers library:
|
||||
|
||||
```
|
||||
pip install flax diffusers transformers
|
||||
```
|
||||
|
||||
In [sdxl_single.py](./sdxl_single.py) we give a simple example of how to write a text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](stabilityai/stable-diffusion-xl-base-1.0).
|
||||
|
||||
Let's explain it step-by-step:
|
||||
|
||||
**Imports and Setup**
|
||||
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax.jax_utils import replicate
|
||||
from diffusers import FlaxStableDiffusionXLPipeline
|
||||
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
cc.initialize_cache("/tmp/sdxl_cache")
|
||||
import time
|
||||
|
||||
NUM_DEVICES = jax.device_count()
|
||||
```
|
||||
|
||||
First, we import the necessary libraries:
|
||||
- `jax` is provides the primitives for TPU operations
|
||||
- `flax.jax_utils` contains some useful utility functions for `Flax`, a neural network library built on top of JAX
|
||||
- `diffusers` has all the code that is relevant for SDXL.
|
||||
- We also initialize a cache to speed up the JAX model compilation.
|
||||
- We automatically determine the number of available TPU devices.
|
||||
|
||||
**1. Downloading Model and Loading Pipeline**
|
||||
|
||||
```python
|
||||
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
|
||||
)
|
||||
```
|
||||
Here, a pre-trained model `stable-diffusion-xl-base-1.0` from the namespace `stabilityai` is loaded. It returns a pipeline for inference and its parameters.
|
||||
|
||||
**2. Casting Parameter Types**
|
||||
|
||||
```python
|
||||
scheduler_state = params.pop("scheduler")
|
||||
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
||||
params["scheduler"] = scheduler_state
|
||||
```
|
||||
This section adjusts the data types of the model parameters.
|
||||
We convert all parameters to `bfloat16` to speed-up the computation with model weights.
|
||||
**Note** that the scheduler parameters are **not** converted to `blfoat16` as the loss
|
||||
in precision is degrading the pipeline's performance too significantly.
|
||||
|
||||
**3. Define Inputs to Pipeline**
|
||||
|
||||
```python
|
||||
default_prompt = ...
|
||||
default_neg_prompt = ...
|
||||
default_seed = 33
|
||||
default_guidance_scale = 5.0
|
||||
default_num_steps = 25
|
||||
```
|
||||
Here, various default inputs for the pipeline are set, including the prompt, negative prompt, random seed, guidance scale, and the number of inference steps.
|
||||
|
||||
**4. Tokenizing Inputs**
|
||||
|
||||
```python
|
||||
def tokenize_prompt(prompt, neg_prompt):
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
|
||||
return prompt_ids, neg_prompt_ids
|
||||
```
|
||||
This function tokenizes the given prompts. It's essential because the text encoders of SDXL don't understand raw text; they work with numbers. Tokenization converts text to numbers.
|
||||
|
||||
**5. Parallelization and Replication**
|
||||
|
||||
```python
|
||||
p_params = replicate(params)
|
||||
|
||||
def replicate_all(prompt_ids, neg_prompt_ids, seed):
|
||||
...
|
||||
```
|
||||
To utilize JAX's parallel capabilities, the parameters and input tensors are duplicated across devices. The `replicate_all` function also ensures that every device produces a different image by creating a unique random seed for each device.
|
||||
|
||||
**6. Putting Everything Together**
|
||||
|
||||
```python
|
||||
def generate(...):
|
||||
...
|
||||
```
|
||||
This function integrates all the steps to produce the desired outputs from the model. It takes in prompts, tokenizes them, replicates them across devices, runs them through the pipeline, and converts the images to a format that's more interpretable (PIL format).
|
||||
|
||||
**7. Compilation Step**
|
||||
|
||||
```python
|
||||
start = time.time()
|
||||
print(f"Compiling ...")
|
||||
generate(default_prompt, default_neg_prompt)
|
||||
print(f"Compiled in {time.time() - start}")
|
||||
```
|
||||
The initial run of the `generate` function will be slow because JAX compiles the function during this call. By running it once here, subsequent calls will be much faster. This section measures and prints the compilation time.
|
||||
|
||||
**8. Fast Inference**
|
||||
|
||||
```python
|
||||
start = time.time()
|
||||
prompt = ...
|
||||
neg_prompt = ...
|
||||
images = generate(prompt, neg_prompt)
|
||||
print(f"Inference in {time.time() - start}")
|
||||
```
|
||||
Now that the function is compiled, this section shows how to use it for fast inference. It measures and prints the inference time.
|
||||
|
||||
In summary, the code demonstrates how to load a pre-trained model using Flax and JAX, prepare it for inference, and run it efficiently using JAX's capabilities.
|
||||
|
||||
## Ahead of Time (AOT) Compilation
|
||||
|
||||
FlaxStableDiffusionXLPipeline takes care of parallelization across multiple devices using jit. Now let's build parallelization ourselves.
|
||||
|
||||
For this we will be using a JAX feature called [Ahead of Time](https://jax.readthedocs.io/en/latest/aot.html) (AOT) lowering and compilation. AOT allows to fully compile prior to execution time and have control over different parts of the compilation process.
|
||||
|
||||
In [sdxl_single_aot.py](./sdxl_single_aot.py) we give a simple example of how to write our own parallelization logic for text-to-image generation pipeline in JAX using [StabilityAI's Stable Diffusion XL](stabilityai/stable-diffusion-xl-base-1.0)
|
||||
|
||||
We add a `aot_compile` function that compiles the `pipeline._generate` function
|
||||
telling JAX which input arguments are static, that is, arguments that
|
||||
are known at compile time and won't change. In our case, it is num_inference_steps,
|
||||
height, width and return_latents.
|
||||
|
||||
Once the function is compiled, these parameters are ommited from future calls and
|
||||
cannot be changed without modifying the code and recompiling.
|
||||
|
||||
```python
|
||||
def aot_compile(
|
||||
prompt=default_prompt,
|
||||
negative_prompt=default_neg_prompt,
|
||||
seed=default_seed,
|
||||
guidance_scale=default_guidance_scale,
|
||||
num_inference_steps=default_num_steps
|
||||
):
|
||||
prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
|
||||
prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
|
||||
g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
|
||||
g = g[:, None]
|
||||
|
||||
return pmap(
|
||||
pipeline._generate,static_broadcasted_argnums=[3, 4, 5, 9]
|
||||
).lower(
|
||||
prompt_ids,
|
||||
p_params,
|
||||
rng,
|
||||
num_inference_steps, # num_inference_steps
|
||||
height, # height
|
||||
width, # width
|
||||
g,
|
||||
None,
|
||||
neg_prompt_ids,
|
||||
False # return_latents
|
||||
).compile()
|
||||
````
|
||||
|
||||
Next we can compile the generate function by executing `aot_compile`.
|
||||
|
||||
```python
|
||||
start = time.time()
|
||||
print("Compiling ...")
|
||||
p_generate = aot_compile()
|
||||
print(f"Compiled in {time.time() - start}")
|
||||
```
|
||||
And again we put everything together in a `generate` function.
|
||||
|
||||
```python
|
||||
def generate(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
seed=default_seed,
|
||||
guidance_scale=default_guidance_scale
|
||||
):
|
||||
prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
|
||||
prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
|
||||
g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
|
||||
g = g[:, None]
|
||||
images = p_generate(
|
||||
prompt_ids,
|
||||
p_params,
|
||||
rng,
|
||||
g,
|
||||
None,
|
||||
neg_prompt_ids)
|
||||
|
||||
# convert the images to PIL
|
||||
images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
|
||||
return pipeline.numpy_to_pil(np.array(images))
|
||||
```
|
||||
|
||||
The first forward pass after AOT compilation still takes a while longer than
|
||||
subsequent passes, this is because on the first pass, JAX uses Python dispatch, which
|
||||
Fills the C++ dispatch cache.
|
||||
When using jit, this extra step is done automatically, but when using AOT compilation,
|
||||
it doesn't happen until the function call is made.
|
||||
|
||||
```python
|
||||
start = time.time()
|
||||
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
|
||||
neg_prompt = "cartoon, illustration, animation. face. male, female"
|
||||
images = generate(prompt, neg_prompt)
|
||||
print(f"First inference in {time.time() - start}")
|
||||
```
|
||||
|
||||
From this point forward, any calls to generate should result in a faster inference
|
||||
time and it won't change.
|
||||
|
||||
```python
|
||||
start = time.time()
|
||||
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
|
||||
neg_prompt = "cartoon, illustration, animation. face. male, female"
|
||||
images = generate(prompt, neg_prompt)
|
||||
print(f"Inference in {time.time() - start}")
|
||||
```
|
||||
@@ -0,0 +1,106 @@
|
||||
# Show best practices for SDXL JAX
|
||||
import time
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax.jax_utils import replicate
|
||||
|
||||
# Let's cache the model compilation, so that it doesn't take as long the next time around.
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
|
||||
from diffusers import FlaxStableDiffusionXLPipeline
|
||||
|
||||
|
||||
cc.initialize_cache("/tmp/sdxl_cache")
|
||||
|
||||
|
||||
NUM_DEVICES = jax.device_count()
|
||||
|
||||
# 1. Let's start by downloading the model and loading it into our pipeline class
|
||||
# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
|
||||
# will have to be passed to the pipeline during inference
|
||||
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
|
||||
)
|
||||
|
||||
# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
|
||||
# float32 to keep maximal precision
|
||||
scheduler_state = params.pop("scheduler")
|
||||
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
||||
params["scheduler"] = scheduler_state
|
||||
|
||||
# 3. Next, we define the different inputs to the pipeline
|
||||
default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
|
||||
default_neg_prompt = "fog, grainy, purple"
|
||||
default_seed = 33
|
||||
default_guidance_scale = 5.0
|
||||
default_num_steps = 25
|
||||
|
||||
|
||||
# 4. In order to be able to compile the pipeline
|
||||
# all inputs have to be tensors or strings
|
||||
# Let's tokenize the prompt and negative prompt
|
||||
def tokenize_prompt(prompt, neg_prompt):
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
|
||||
return prompt_ids, neg_prompt_ids
|
||||
|
||||
|
||||
# 5. To make full use of JAX's parallelization capabilities
|
||||
# the parameters and input tensors are duplicated across devices
|
||||
# To make sure every device generates a different image, we create
|
||||
# different seeds for each image. The model parameters won't change
|
||||
# during inference so we do not wrap them into a function
|
||||
p_params = replicate(params)
|
||||
|
||||
|
||||
def replicate_all(prompt_ids, neg_prompt_ids, seed):
|
||||
p_prompt_ids = replicate(prompt_ids)
|
||||
p_neg_prompt_ids = replicate(neg_prompt_ids)
|
||||
rng = jax.random.PRNGKey(seed)
|
||||
rng = jax.random.split(rng, NUM_DEVICES)
|
||||
return p_prompt_ids, p_neg_prompt_ids, rng
|
||||
|
||||
|
||||
# 6. Let's now put it all together in a generate function
|
||||
def generate(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
seed=default_seed,
|
||||
guidance_scale=default_guidance_scale,
|
||||
num_inference_steps=default_num_steps,
|
||||
):
|
||||
prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
|
||||
prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
|
||||
images = pipeline(
|
||||
prompt_ids,
|
||||
p_params,
|
||||
rng,
|
||||
num_inference_steps=num_inference_steps,
|
||||
neg_prompt_ids=neg_prompt_ids,
|
||||
guidance_scale=guidance_scale,
|
||||
jit=True,
|
||||
).images
|
||||
|
||||
# convert the images to PIL
|
||||
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
|
||||
return pipeline.numpy_to_pil(np.array(images))
|
||||
|
||||
|
||||
# 7. Remember that the first call will compile the function and hence be very slow. Let's run generate once
|
||||
# so that the pipeline call is compiled
|
||||
start = time.time()
|
||||
print("Compiling ...")
|
||||
generate(default_prompt, default_neg_prompt)
|
||||
print(f"Compiled in {time.time() - start}")
|
||||
|
||||
# 8. Now the model forward pass will run very quickly, let's try it again
|
||||
start = time.time()
|
||||
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
|
||||
neg_prompt = "cartoon, illustration, animation. face. male, female"
|
||||
images = generate(prompt, neg_prompt)
|
||||
print(f"Inference in {time.time() - start}")
|
||||
|
||||
for i, image in enumerate(images):
|
||||
image.save(f"castle_{i}.png")
|
||||
@@ -0,0 +1,143 @@
|
||||
import time
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from flax.jax_utils import replicate
|
||||
from jax import pmap
|
||||
|
||||
# Let's cache the model compilation, so that it doesn't take as long the next time around.
|
||||
from jax.experimental.compilation_cache import compilation_cache as cc
|
||||
|
||||
from diffusers import FlaxStableDiffusionXLPipeline
|
||||
|
||||
|
||||
cc.initialize_cache("/tmp/sdxl_cache")
|
||||
|
||||
|
||||
NUM_DEVICES = jax.device_count()
|
||||
|
||||
# 1. Let's start by downloading the model and loading it into our pipeline class
|
||||
# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
|
||||
# will have to be passed to the pipeline during inference
|
||||
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
|
||||
)
|
||||
|
||||
# 2. We cast all parameters to bfloat16 EXCEPT the scheduler which we leave in
|
||||
# float32 to keep maximal precision
|
||||
scheduler_state = params.pop("scheduler")
|
||||
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
|
||||
params["scheduler"] = scheduler_state
|
||||
|
||||
# 3. Next, we define the different inputs to the pipeline
|
||||
default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
|
||||
default_neg_prompt = "fog, grainy, purple"
|
||||
default_seed = 33
|
||||
default_guidance_scale = 5.0
|
||||
default_num_steps = 25
|
||||
width = 1024
|
||||
height = 1024
|
||||
|
||||
|
||||
# 4. In order to be able to compile the pipeline
|
||||
# all inputs have to be tensors or strings
|
||||
# Let's tokenize the prompt and negative prompt
|
||||
def tokenize_prompt(prompt, neg_prompt):
|
||||
prompt_ids = pipeline.prepare_inputs(prompt)
|
||||
neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
|
||||
return prompt_ids, neg_prompt_ids
|
||||
|
||||
|
||||
# 5. To make full use of JAX's parallelization capabilities
|
||||
# the parameters and input tensors are duplicated across devices
|
||||
# To make sure every device generates a different image, we create
|
||||
# different seeds for each image. The model parameters won't change
|
||||
# during inference so we do not wrap them into a function
|
||||
p_params = replicate(params)
|
||||
|
||||
|
||||
def replicate_all(prompt_ids, neg_prompt_ids, seed):
|
||||
p_prompt_ids = replicate(prompt_ids)
|
||||
p_neg_prompt_ids = replicate(neg_prompt_ids)
|
||||
rng = jax.random.PRNGKey(seed)
|
||||
rng = jax.random.split(rng, NUM_DEVICES)
|
||||
return p_prompt_ids, p_neg_prompt_ids, rng
|
||||
|
||||
|
||||
# 6. To compile the pipeline._generate function, we must pass all parameters
|
||||
# to the function and tell JAX which are static arguments, that is, arguments that
|
||||
# are known at compile time and won't change. In our case, it is num_inference_steps,
|
||||
# height, width and return_latents.
|
||||
# Once the function is compiled, these parameters are ommited from future calls and
|
||||
# cannot be changed without modifying the code and recompiling.
|
||||
def aot_compile(
|
||||
prompt=default_prompt,
|
||||
negative_prompt=default_neg_prompt,
|
||||
seed=default_seed,
|
||||
guidance_scale=default_guidance_scale,
|
||||
num_inference_steps=default_num_steps,
|
||||
):
|
||||
prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
|
||||
prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
|
||||
g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
|
||||
g = g[:, None]
|
||||
|
||||
return (
|
||||
pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9])
|
||||
.lower(
|
||||
prompt_ids,
|
||||
p_params,
|
||||
rng,
|
||||
num_inference_steps, # num_inference_steps
|
||||
height, # height
|
||||
width, # width
|
||||
g,
|
||||
None,
|
||||
neg_prompt_ids,
|
||||
False, # return_latents
|
||||
)
|
||||
.compile()
|
||||
)
|
||||
|
||||
|
||||
start = time.time()
|
||||
print("Compiling ...")
|
||||
p_generate = aot_compile()
|
||||
print(f"Compiled in {time.time() - start}")
|
||||
|
||||
|
||||
# 7. Let's now put it all together in a generate function.
|
||||
def generate(prompt, negative_prompt, seed=default_seed, guidance_scale=default_guidance_scale):
|
||||
prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
|
||||
prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
|
||||
g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32)
|
||||
g = g[:, None]
|
||||
images = p_generate(prompt_ids, p_params, rng, g, None, neg_prompt_ids)
|
||||
|
||||
# convert the images to PIL
|
||||
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
|
||||
return pipeline.numpy_to_pil(np.array(images))
|
||||
|
||||
|
||||
# 8. The first forward pass after AOT compilation still takes a while longer than
|
||||
# subsequent passes, this is because on the first pass, JAX uses Python dispatch, which
|
||||
# Fills the C++ dispatch cache.
|
||||
# When using jit, this extra step is done automatically, but when using AOT compilation,
|
||||
# it doesn't happen until the function call is made.
|
||||
start = time.time()
|
||||
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
|
||||
neg_prompt = "cartoon, illustration, animation. face. male, female"
|
||||
images = generate(prompt, neg_prompt)
|
||||
print(f"First inference in {time.time() - start}")
|
||||
|
||||
# 9. From this point forward, any calls to generate should result in a faster inference
|
||||
# time and it won't change.
|
||||
start = time.time()
|
||||
prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
|
||||
neg_prompt = "cartoon, illustration, animation. face. male, female"
|
||||
images = generate(prompt, neg_prompt)
|
||||
print(f"Inference in {time.time() - start}")
|
||||
|
||||
for i, image in enumerate(images):
|
||||
image.save(f"castle_{i}.png")
|
||||
@@ -20,7 +20,7 @@ pip install -e .
|
||||
|
||||
Then cd in the `examples/t2i_adapter` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sdxl.txt
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
@@ -44,7 +44,7 @@ from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
### Training
|
||||
|
||||
@@ -73,10 +73,10 @@ accelerate launch train_text_to_image_sdxl.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
**Notes**:
|
||||
**Notes**:
|
||||
|
||||
* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
|
||||
* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
|
||||
* The `train_text_to_image_sdxl.py` script pre-computes text embeddings and the VAE encodings and keeps them in memory. While for smaller datasets like [`lambdalabs/pokemon-blip-captions`](https://hf.co/datasets/lambdalabs/pokemon-blip-captions), it might not be a problem, it can definitely lead to memory problems when the script is used on a larger dataset. For those purposes, you would want to serialize these pre-computed representations to disk separately and load them during the fine-tuning process. Refer to [this PR](https://github.com/huggingface/diffusers/pull/4505) for a more in-depth discussion.
|
||||
* The training script is compute-intensive and may not run on a consumer GPU like Tesla T4.
|
||||
* The training command shown above performs intermediate quality validation in between the training epochs and logs the results to Weights and Biases. `--report_to`, `--validation_prompt`, and `--validation_epochs` are the relevant CLI arguments here.
|
||||
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
@@ -95,6 +95,35 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
||||
image.save("pokemon.png")
|
||||
```
|
||||
|
||||
### Inference in Pytorch XLA
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
pipe = DiffusionPipeline.from_pretrained(model_id)
|
||||
|
||||
device = xm.xla_device()
|
||||
pipe.to(device)
|
||||
|
||||
prompt = "A pokemon with green eyes and red legs."
|
||||
start = time()
|
||||
image = pipe(prompt, num_inference_steps=inference_steps).images[0]
|
||||
print(f'Compilation time is {time()-start} sec')
|
||||
image.save("pokemon.png")
|
||||
|
||||
start = time()
|
||||
image = pipe(prompt, num_inference_steps=inference_steps).images[0]
|
||||
print(f'Inference time is {time()-start} sec after compilation')
|
||||
```
|
||||
|
||||
Note: There is a warmup step in PyTorch XLA. This takes longer because of
|
||||
compilation and optimization. To see the real benefits of Pytorch XLA and
|
||||
speedup, we need to call the pipe again on the input with the same length
|
||||
as the original prompt to reuse the optimized graph and get the performance
|
||||
boost.
|
||||
|
||||
## LoRA training example for Stable Diffusion XL (SDXL)
|
||||
|
||||
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
|
||||
@@ -112,7 +141,7 @@ on consumer GPUs like Tesla T4, Tesla V100.
|
||||
|
||||
### Training
|
||||
|
||||
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
|
||||
First, you need to set up your development environment as is explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables and, optionally, the `VAE_NAME` variable. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokemons dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
|
||||
|
||||
**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training to automatically log images.___**
|
||||
|
||||
@@ -122,7 +151,7 @@ export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
|
||||
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
|
||||
```
|
||||
|
||||
For this example we want to directly store the trained LoRA embeddings on the Hub, so
|
||||
For this example we want to directly store the trained LoRA embeddings on the Hub, so
|
||||
we need to be logged in and add the `--push_to_hub` flag.
|
||||
|
||||
```bash
|
||||
@@ -149,7 +178,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \
|
||||
|
||||
The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
|
||||
|
||||
**Notes**:
|
||||
**Notes**:
|
||||
|
||||
* SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
@@ -178,7 +207,7 @@ accelerate launch train_text_to_image_lora_sdxl.py \
|
||||
|
||||
### Inference
|
||||
|
||||
Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You
|
||||
Once you have trained a model using above command, the inference can be done simply using the `DiffusionPipeline` after loading the trained LoRA weights. You
|
||||
need to pass the `output_dir` for loading the LoRA weights which, in this case, is `sd-pokemon-model-lora-sdxl`.
|
||||
|
||||
```python
|
||||
|
||||
@@ -4,3 +4,4 @@ transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
datasets
|
||||
|
||||
@@ -43,7 +43,7 @@ from transformers.utils import ContextManagers
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -577,9 +577,10 @@ def main():
|
||||
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
|
||||
)
|
||||
|
||||
# Freeze vae and text_encoder
|
||||
# Freeze vae and text_encoder and set unet to trainable
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.train()
|
||||
|
||||
# Create EMA for the unet.
|
||||
if args.use_ema:
|
||||
@@ -601,30 +602,6 @@ def main():
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -878,29 +855,29 @@ def main():
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
|
||||
@@ -951,26 +928,14 @@ def main():
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -43,6 +43,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DCon
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -428,7 +429,6 @@ def main():
|
||||
# freeze parameters of models to save more memory
|
||||
unet.requires_grad_(False)
|
||||
vae.requires_grad_(False)
|
||||
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
@@ -491,30 +491,6 @@ def main():
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
lora_layers = AttnProcsLayers(unet.attn_processors)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
@@ -713,29 +689,29 @@ def main():
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
@@ -782,26 +758,14 @@ def main():
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -50,8 +50,9 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -608,53 +609,42 @@ def main(args):
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
# Set correct lora layers
|
||||
unet_lora_attn_procs = {}
|
||||
unet_lora_parameters = []
|
||||
for name, attn_processor in unet.attn_processors.items():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
for attn_processor_name, attn_processor in unet.attn_processors.items():
|
||||
# Parse the attention module.
|
||||
attn_module = unet
|
||||
for n in attn_processor_name.split(".")[:-1]:
|
||||
attn_module = getattr(attn_module, n)
|
||||
|
||||
lora_attn_processor_class = (
|
||||
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
|
||||
# Set the `lora_layer` attribute of the attention-related matrices.
|
||||
attn_module.to_q.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
module = lora_attn_processor_class(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=args.rank
|
||||
attn_module.to_k.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_v.set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
|
||||
)
|
||||
)
|
||||
attn_module.to_out[0].set_lora_layer(
|
||||
LoRALinearLayer(
|
||||
in_features=attn_module.to_out[0].in_features,
|
||||
out_features=attn_module.to_out[0].out_features,
|
||||
rank=args.rank,
|
||||
)
|
||||
)
|
||||
unet_lora_attn_procs[name] = module
|
||||
unet_lora_parameters.extend(module.parameters())
|
||||
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
# Accumulate the LoRA params to optimize.
|
||||
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
|
||||
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
@@ -970,18 +960,25 @@ def main(args):
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
@@ -990,12 +987,6 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
if args.pretrained_vae_model_name_or_path is not None:
|
||||
@@ -1071,26 +1062,14 @@ def main(args):
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -51,7 +51,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
@@ -657,6 +657,8 @@ def main(args):
|
||||
vae.requires_grad_(False)
|
||||
text_encoder_one.requires_grad_(False)
|
||||
text_encoder_two.requires_grad_(False)
|
||||
# Set unet as trainable.
|
||||
unet.train()
|
||||
|
||||
# For mixed precision training we cast all non-trainable weigths to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
@@ -692,30 +694,6 @@ def main(args):
|
||||
else:
|
||||
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
||||
|
||||
def compute_snr(timesteps):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR.
|
||||
snr = (alpha / sigma) ** 2
|
||||
return snr
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
@@ -991,29 +969,29 @@ def main(args):
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
unet.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(unet):
|
||||
# Sample noise that we'll add to the latents
|
||||
model_input = batch["model_input"].to(accelerator.device)
|
||||
@@ -1088,26 +1066,14 @@ def main(args):
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(timesteps)
|
||||
base_weight = (
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
|
||||
# For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value.
|
||||
# When we run this, the MSE loss weights for this timestep is set unconditionally to 1.
|
||||
# If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
|
||||
# We first calculate the original loss. Then we mean over the non-batch dimensions and
|
||||
# rebalance the sample-wise losses with their respective loss weights.
|
||||
# Finally, we take the mean of the rebalanced loss.
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -809,18 +809,25 @@ def main():
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
initial_global_step = 0
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
initial_global_step = global_step
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
else:
|
||||
initial_global_step = 0
|
||||
|
||||
progress_bar = tqdm(
|
||||
range(0, args.max_train_steps),
|
||||
initial=initial_global_step,
|
||||
desc="Steps",
|
||||
# Only show the progress bar once on each machine.
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
# keep original embeddings as reference
|
||||
orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
|
||||
@@ -828,12 +835,6 @@ def main():
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
text_encoder.train()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||
|
||||
@@ -607,28 +607,28 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
checkpoints = os.listdir(args.output_dir)
|
||||
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
checkpoints = os.listdir(args.output_dir)
|
||||
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
|
||||
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= args.checkpoints_total_limit:
|
||||
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= args.checkpoints_total_limit:
|
||||
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
||||
shutil.rmtree(removing_checkpoint)
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
||||
shutil.rmtree(removing_checkpoint)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
@@ -73,17 +73,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
new_item = new_item.replace("q.weight", "to_q.weight")
|
||||
new_item = new_item.replace("q.bias", "to_q.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
new_item = new_item.replace("k.weight", "to_k.weight")
|
||||
new_item = new_item.replace("k.bias", "to_k.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
new_item = new_item.replace("v.weight", "to_v.weight")
|
||||
new_item = new_item.replace("v.bias", "to_v.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
||||
|
||||
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
||||
|
||||
@@ -92,6 +92,19 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
return mapping
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint
|
||||
# config.num_head_channels => num_head_channels
|
||||
def assign_to_checkpoint(
|
||||
@@ -104,8 +117,9 @@ def assign_to_checkpoint(
|
||||
):
|
||||
"""
|
||||
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
||||
attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new
|
||||
checkpoint.
|
||||
attention layers, and takes into account additional replacements that may arise.
|
||||
|
||||
Assigns the weights to the new checkpoint.
|
||||
"""
|
||||
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||
|
||||
@@ -143,25 +157,16 @@ def assign_to_checkpoint(
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "proj_attn.weight" in new_path:
|
||||
is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path)
|
||||
shape = old_checkpoint[path["old"]].shape
|
||||
if is_attn_weight and len(shape) == 3:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||
elif is_attn_weight and len(shape) == 4:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
|
||||
|
||||
def create_vae_diffusers_config(config_type):
|
||||
# Hardcoded for now
|
||||
if args.config_type == "test":
|
||||
@@ -339,7 +344,7 @@ def create_text_decoder_config_big():
|
||||
return text_decoder_config
|
||||
|
||||
|
||||
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint
|
||||
# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.convert_ldm_vae_checkpoint
|
||||
def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1):
|
||||
"""
|
||||
Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL.
|
||||
@@ -674,6 +679,11 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--safe_serialization",
|
||||
action="store_true",
|
||||
help="Whether to use safetensors/safe seialization when saving the pipeline.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -766,11 +776,11 @@ if __name__ == "__main__":
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
image_encoder=image_encoder,
|
||||
image_processor=image_processor,
|
||||
clip_image_processor=image_processor,
|
||||
clip_tokenizer=clip_tokenizer,
|
||||
text_decoder=text_decoder,
|
||||
text_tokenizer=text_tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipeline.save_pretrained(args.pipeline_output_path)
|
||||
pipeline.save_pretrained(args.pipeline_output_path, safe_serialization=args.safe_serialization)
|
||||
|
||||
@@ -102,8 +102,8 @@ _deps = [
|
||||
"importlib_metadata",
|
||||
"invisible-watermark>=0.2.0",
|
||||
"isort>=5.5.4",
|
||||
"jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib>=0.1.65",
|
||||
"jax>=0.4.1",
|
||||
"jaxlib>=0.4.1",
|
||||
"Jinja2",
|
||||
"k-diffusion>=0.0.12",
|
||||
"torchsde",
|
||||
@@ -255,6 +255,7 @@ setup(
|
||||
url="https://github.com/huggingface/diffusers",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
package_data={"diffusers": ["py.typed"]},
|
||||
include_package_data=True,
|
||||
python_requires=">=3.8.0",
|
||||
install_requires=list(install_requires),
|
||||
|
||||
@@ -15,8 +15,8 @@ deps = {
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"invisible-watermark": "invisible-watermark>=0.2.0",
|
||||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib": "jaxlib>=0.1.65",
|
||||
"jax": "jax>=0.4.1",
|
||||
"jaxlib": "jaxlib>=0.4.1",
|
||||
"Jinja2": "Jinja2",
|
||||
"k-diffusion": "k-diffusion>=0.0.12",
|
||||
"torchsde": "torchsde",
|
||||
|
||||
+524
-93
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
@@ -27,26 +26,33 @@ from huggingface_hub import hf_hub_download, model_info
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
from . import __version__
|
||||
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HF_HUB_OFFLINE,
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
convert_unet_state_dict_to_peft,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_omegaconf_available,
|
||||
is_peft_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .utils.import_utils import BACKENDS_MAPPING
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
@@ -66,19 +72,6 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
||||
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
# available.
|
||||
# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
) > version.parse("0.5")
|
||||
_required_transformers_version = version.parse(
|
||||
version.parse(importlib.metadata.version("transformers")).base_version
|
||||
) > version.parse("4.33")
|
||||
|
||||
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
|
||||
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
|
||||
|
||||
|
||||
@@ -115,7 +108,7 @@ class PatchedLoraProjection(nn.Module):
|
||||
|
||||
return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
|
||||
|
||||
def _fuse_lora(self, lora_scale=1.0):
|
||||
def _fuse_lora(self, lora_scale=1.0, safe_fusing=False):
|
||||
if self.lora_linear_layer is None:
|
||||
return
|
||||
|
||||
@@ -129,6 +122,14 @@ class PatchedLoraProjection(nn.Module):
|
||||
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
|
||||
|
||||
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
||||
|
||||
if safe_fusing and torch.isnan(fused_weight).any().item():
|
||||
raise ValueError(
|
||||
"This LoRA weight seems to be broken. "
|
||||
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
|
||||
"LoRA weights will not be fused."
|
||||
)
|
||||
|
||||
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
# we can drop the lora layer now
|
||||
@@ -399,7 +400,7 @@ class UNet2DConditionLoadersMixin:
|
||||
# fill attn processors
|
||||
lora_layers_list = []
|
||||
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
|
||||
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
||||
|
||||
if is_lora:
|
||||
@@ -513,6 +514,10 @@ class UNet2DConditionLoadersMixin:
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
elif USE_PEFT_BACKEND:
|
||||
# In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
|
||||
# on the Unet
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
||||
@@ -523,33 +528,36 @@ class UNet2DConditionLoadersMixin:
|
||||
# Now we remove any existing hooks to
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module):
|
||||
if hasattr(component, "_hf_hook"):
|
||||
|
||||
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
|
||||
if not USE_PEFT_BACKEND:
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
|
||||
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
# only custom diffusion needs to set attn processors
|
||||
if is_custom_diffusion:
|
||||
self.set_attn_processor(attn_processors)
|
||||
# only custom diffusion needs to set attn processors
|
||||
if is_custom_diffusion:
|
||||
self.set_attn_processor(attn_processors)
|
||||
|
||||
# set lora layers
|
||||
for target_module, lora_layer in lora_layers_list:
|
||||
target_module.set_lora_layer(lora_layer)
|
||||
# set lora layers
|
||||
for target_module, lora_layer in lora_layers_list:
|
||||
target_module.set_lora_layer(lora_layer)
|
||||
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
self.to(dtype=self.dtype, device=self.device)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
|
||||
is_new_lora_format = all(
|
||||
@@ -666,20 +674,83 @@ class UNet2DConditionLoadersMixin:
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
|
||||
|
||||
def fuse_lora(self, lora_scale=1.0):
|
||||
def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
|
||||
self.lora_scale = lora_scale
|
||||
self._safe_fusing = safe_fusing
|
||||
self.apply(self._fuse_lora_apply)
|
||||
|
||||
def _fuse_lora_apply(self, module):
|
||||
if hasattr(module, "_fuse_lora"):
|
||||
module._fuse_lora(self.lora_scale)
|
||||
if not USE_PEFT_BACKEND:
|
||||
if hasattr(module, "_fuse_lora"):
|
||||
module._fuse_lora(self.lora_scale, self._safe_fusing)
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if self.lora_scale != 1.0:
|
||||
module.scale_layer(self.lora_scale)
|
||||
module.merge(safe_merge=self._safe_fusing)
|
||||
|
||||
def unfuse_lora(self):
|
||||
self.apply(self._unfuse_lora_apply)
|
||||
|
||||
def _unfuse_lora_apply(self, module):
|
||||
if hasattr(module, "_unfuse_lora"):
|
||||
module._unfuse_lora()
|
||||
if not USE_PEFT_BACKEND:
|
||||
if hasattr(module, "_unfuse_lora"):
|
||||
module._unfuse_lora()
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
weights: Optional[Union[List[float], float]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the unet.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
weights (`Union[List[float], float]`, *optional*):
|
||||
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
|
||||
adapters.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `set_adapters()`.")
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
|
||||
if weights is None:
|
||||
weights = [1.0] * len(adapter_names)
|
||||
elif isinstance(weights, float):
|
||||
weights = [weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
|
||||
)
|
||||
|
||||
set_weights_and_activate_adapters(self, adapter_names, weights)
|
||||
|
||||
def disable_lora(self):
|
||||
"""
|
||||
Disables the active LoRA layers for the unet.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
set_adapter_layers(self, enabled=False)
|
||||
|
||||
def enable_lora(self):
|
||||
"""
|
||||
Enables the active LoRA layers for the unet.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
set_adapter_layers(self, enabled=True)
|
||||
|
||||
|
||||
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
|
||||
@@ -1098,9 +1169,10 @@ class LoraLoaderMixin:
|
||||
text_encoder_name = TEXT_ENCODER_NAME
|
||||
unet_name = UNET_NAME
|
||||
num_fused_loras = 0
|
||||
use_peft_backend = USE_PEFT_BACKEND
|
||||
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
`self.text_encoder`.
|
||||
@@ -1120,6 +1192,9 @@ class LoraLoaderMixin:
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
@@ -1135,6 +1210,7 @@ class LoraLoaderMixin:
|
||||
network_alphas=network_alphas,
|
||||
unet=self.unet,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
)
|
||||
self.load_lora_into_text_encoder(
|
||||
@@ -1143,6 +1219,7 @@ class LoraLoaderMixin:
|
||||
text_encoder=self.text_encoder,
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
)
|
||||
|
||||
@@ -1443,7 +1520,40 @@ class LoraLoaderMixin:
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_unet(
|
||||
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `unet`.
|
||||
|
||||
@@ -1461,6 +1571,9 @@ class LoraLoaderMixin:
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
@@ -1487,6 +1600,56 @@ class LoraLoaderMixin:
|
||||
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
|
||||
logger.warn(warn_message)
|
||||
|
||||
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
if adapter_name in getattr(unet, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
|
||||
)
|
||||
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if network_alphas is not None:
|
||||
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
|
||||
# `convert_unet_state_dict_to_peft` method.
|
||||
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(unet)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
|
||||
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
unet.load_attn_procs(
|
||||
state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
|
||||
)
|
||||
@@ -1500,6 +1663,7 @@ class LoraLoaderMixin:
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
low_cpu_mem_usage=None,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
):
|
||||
"""
|
||||
@@ -1523,6 +1687,9 @@ class LoraLoaderMixin:
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
|
||||
@@ -1545,7 +1712,7 @@ class LoraLoaderMixin:
|
||||
rank = {}
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
|
||||
|
||||
if cls.use_peft_backend:
|
||||
if USE_PEFT_BACKEND:
|
||||
# convert state dict
|
||||
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
|
||||
|
||||
@@ -1558,6 +1725,7 @@ class LoraLoaderMixin:
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
|
||||
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
|
||||
|
||||
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
|
||||
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
|
||||
else:
|
||||
@@ -1581,25 +1749,31 @@ class LoraLoaderMixin:
|
||||
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
if cls.use_peft_backend:
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_rank = list(rank.values())[0]
|
||||
# By definition, the scale should be alpha divided by rank.
|
||||
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
|
||||
alpha = lora_scale * lora_rank
|
||||
lora_config_kwargs = get_peft_kwargs(
|
||||
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
|
||||
)
|
||||
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
if patch_mlp:
|
||||
target_modules += ["fc1", "fc2"]
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
|
||||
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=text_encoder_lora_state_dict,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
else:
|
||||
cls._modify_text_encoder(
|
||||
text_encoder,
|
||||
@@ -1671,7 +1845,7 @@ class LoraLoaderMixin:
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if self.use_peft_backend:
|
||||
if USE_PEFT_BACKEND:
|
||||
remove_method = recurse_remove_peft_layers
|
||||
else:
|
||||
remove_method = self._remove_text_encoder_monkey_patch_classmethod
|
||||
@@ -1679,18 +1853,20 @@ class LoraLoaderMixin:
|
||||
if hasattr(self, "text_encoder"):
|
||||
remove_method(self.text_encoder)
|
||||
|
||||
if self.use_peft_backend:
|
||||
# In case text encoder have no Lora attached
|
||||
if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
remove_method(self.text_encoder_2)
|
||||
if self.use_peft_backend:
|
||||
if USE_PEFT_BACKEND:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
|
||||
@classmethod
|
||||
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.25", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
@@ -1718,7 +1894,8 @@ class LoraLoaderMixin:
|
||||
r"""
|
||||
Monkey-patches the forward passes of attention modules of the text encoder.
|
||||
"""
|
||||
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
deprecate("_modify_text_encoder", "0.25", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
|
||||
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
|
||||
@@ -2058,14 +2235,31 @@ class LoraLoaderMixin:
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
for _, module in self.unet.named_modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
if not USE_PEFT_BACKEND:
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
logger.warn(
|
||||
"You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
|
||||
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
|
||||
)
|
||||
|
||||
for _, module in self.unet.named_modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
else:
|
||||
recurse_remove_peft_layers(self.unet)
|
||||
if hasattr(self.unet, "peft_config"):
|
||||
del self.unet.peft_config
|
||||
|
||||
# Safe to call the following regardless of LoRA.
|
||||
self._remove_text_encoder_monkey_patch()
|
||||
|
||||
def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
|
||||
def fuse_lora(
|
||||
self,
|
||||
fuse_unet: bool = True,
|
||||
fuse_text_encoder: bool = True,
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
@@ -2082,6 +2276,8 @@ class LoraLoaderMixin:
|
||||
LoRA parameters then it won't have any effect.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
"""
|
||||
if fuse_unet or fuse_text_encoder:
|
||||
self.num_fused_loras += 1
|
||||
@@ -2091,12 +2287,13 @@ class LoraLoaderMixin:
|
||||
)
|
||||
|
||||
if fuse_unet:
|
||||
self.unet.fuse_lora(lora_scale)
|
||||
self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
|
||||
|
||||
if self.use_peft_backend:
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
||||
# TODO(Patrick, Younes): enable "safe" fusing
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
@@ -2105,26 +2302,27 @@ class LoraLoaderMixin:
|
||||
module.merge()
|
||||
|
||||
else:
|
||||
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
deprecate("fuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj._fuse_lora(lora_scale)
|
||||
attn_module.k_proj._fuse_lora(lora_scale)
|
||||
attn_module.v_proj._fuse_lora(lora_scale)
|
||||
attn_module.out_proj._fuse_lora(lora_scale)
|
||||
attn_module.q_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.k_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.v_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
attn_module.out_proj._fuse_lora(lora_scale, safe_fusing)
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1._fuse_lora(lora_scale)
|
||||
mlp_module.fc2._fuse_lora(lora_scale)
|
||||
mlp_module.fc1._fuse_lora(lora_scale, safe_fusing)
|
||||
mlp_module.fc2._fuse_lora(lora_scale, safe_fusing)
|
||||
|
||||
if fuse_text_encoder:
|
||||
if hasattr(self, "text_encoder"):
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale)
|
||||
fuse_text_encoder_lora(self.text_encoder, lora_scale, safe_fusing)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
|
||||
fuse_text_encoder_lora(self.text_encoder_2, lora_scale, safe_fusing)
|
||||
|
||||
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
|
||||
r"""
|
||||
@@ -2144,10 +2342,17 @@ class LoraLoaderMixin:
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
if unfuse_unet:
|
||||
self.unet.unfuse_lora()
|
||||
if not USE_PEFT_BACKEND:
|
||||
self.unet.unfuse_lora()
|
||||
else:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if self.use_peft_backend:
|
||||
from peft.tuners.tuner_utils import BaseTunerLayer
|
||||
for module in self.unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for module in text_encoder.modules():
|
||||
@@ -2155,7 +2360,8 @@ class LoraLoaderMixin:
|
||||
module.unmerge()
|
||||
|
||||
else:
|
||||
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
|
||||
if version.parse(__version__) > version.parse("0.23"):
|
||||
deprecate("unfuse_text_encoder_lora", "0.25", LORA_DEPRECATION_MESSAGE)
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
@@ -2178,6 +2384,214 @@ class LoraLoaderMixin:
|
||||
|
||||
self.num_fused_loras -= 1
|
||||
|
||||
def set_adapters_for_text_encoder(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
text_encoder: Optional[PreTrainedModel] = None,
|
||||
text_encoder_weights: List[float] = None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
text_encoder_weights (`List[float]`, *optional*):
|
||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
if weights is None:
|
||||
weights = [1.0] * len(adapter_names)
|
||||
elif isinstance(weights, float):
|
||||
weights = [weights]
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
|
||||
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
||||
if text_encoder is None:
|
||||
raise ValueError(
|
||||
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
|
||||
)
|
||||
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
||||
|
||||
def disable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
|
||||
"""
|
||||
Disables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
|
||||
`text_encoder` attribute.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=False)
|
||||
|
||||
def enable_lora_for_text_encoder(self, text_encoder: Optional[PreTrainedModel] = None):
|
||||
"""
|
||||
Enables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
text_encoder = text_encoder or getattr(self, "text_encoder", None)
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(self.text_encoder, enabled=True)
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[List[float]] = None,
|
||||
):
|
||||
# Handle the UNET
|
||||
self.unet.set_adapters(adapter_names, adapter_weights)
|
||||
|
||||
# Handle the Text Encoder
|
||||
if hasattr(self, "text_encoder"):
|
||||
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
|
||||
|
||||
def disable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# Disable unet adapters
|
||||
self.unet.disable_lora()
|
||||
|
||||
# Disable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
self.disable_lora_for_text_encoder(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.disable_lora_for_text_encoder(self.text_encoder_2)
|
||||
|
||||
def enable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# Enable unet adapters
|
||||
self.unet.enable_lora()
|
||||
|
||||
# Enable text encoder adapters
|
||||
if hasattr(self, "text_encoder"):
|
||||
self.enable_lora_for_text_encoder(self.text_encoder)
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
self.enable_lora_for_text_encoder(self.text_encoder_2)
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
pipeline.get_active_adapters()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
active_adapters = []
|
||||
|
||||
for module in self.unet.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
break
|
||||
|
||||
return active_adapters
|
||||
|
||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Gets the current list of all available adapters in the pipeline.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
set_adapters = {}
|
||||
|
||||
if hasattr(self, "text_encoder") and hasattr(self.text_encoder, "peft_config"):
|
||||
set_adapters["text_encoder"] = list(self.text_encoder.peft_config.keys())
|
||||
|
||||
if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
|
||||
set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
|
||||
|
||||
if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
|
||||
set_adapters["unet"] = list(self.unet.peft_config.keys())
|
||||
|
||||
return set_adapters
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send device to.
|
||||
device (`Union[torch.device, str, int]`):
|
||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
# Handle the UNET
|
||||
for unet_module in self.unet.modules():
|
||||
if isinstance(unet_module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
unet_module.lora_A[adapter_name].to(device)
|
||||
unet_module.lora_B[adapter_name].to(device)
|
||||
|
||||
# Handle the text encoder
|
||||
modules_to_process = []
|
||||
if hasattr(self, "text_encoder"):
|
||||
modules_to_process.append(self.text_encoder)
|
||||
|
||||
if hasattr(self, "text_encoder_2"):
|
||||
modules_to_process.append(self.text_encoder_2)
|
||||
|
||||
for text_encoder in modules_to_process:
|
||||
# loop over submodules
|
||||
for text_encoder_module in text_encoder.modules():
|
||||
if isinstance(text_encoder_module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
text_encoder_module.lora_A[adapter_name].to(device)
|
||||
text_encoder_module.lora_B[adapter_name].to(device)
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
@@ -2335,8 +2749,12 @@ class FromSingleFileMixin:
|
||||
from .models.controlnet import ControlNetModel
|
||||
from .pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
|
||||
# Model type will be inferred from the checkpoint.
|
||||
if not isinstance(controlnet, (ControlNetModel, MultiControlNetModel)):
|
||||
# list/tuple or a single instance of ControlNetModel or MultiControlNetModel
|
||||
if not (
|
||||
isinstance(controlnet, (ControlNetModel, MultiControlNetModel))
|
||||
or isinstance(controlnet, (list, tuple))
|
||||
and isinstance(controlnet[0], ControlNetModel)
|
||||
):
|
||||
raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.")
|
||||
elif "StableDiffusion" in pipeline_name:
|
||||
# Model type will be inferred from the checkpoint.
|
||||
@@ -2758,7 +3176,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
|
||||
|
||||
# Overrride to properly handle the loading and unloading of the additional text encoder.
|
||||
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
|
||||
`self.text_encoder`.
|
||||
@@ -2776,6 +3199,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
@@ -2793,7 +3219,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
|
||||
self.load_lora_into_unet(
|
||||
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
|
||||
)
|
||||
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
|
||||
if len(text_encoder_state_dict) > 0:
|
||||
self.load_lora_into_text_encoder(
|
||||
@@ -2802,6 +3230,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
text_encoder=self.text_encoder,
|
||||
prefix="text_encoder",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
)
|
||||
|
||||
@@ -2813,6 +3242,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
text_encoder=self.text_encoder_2,
|
||||
prefix="text_encoder_2",
|
||||
lora_scale=self.lora_scale,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
)
|
||||
|
||||
@@ -2879,16 +3309,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
)
|
||||
|
||||
def _remove_text_encoder_monkey_patch(self):
|
||||
if self.use_peft_backend:
|
||||
if USE_PEFT_BACKEND:
|
||||
recurse_remove_peft_layers(self.text_encoder)
|
||||
# TODO: @younesbelkada handle this in transformers side
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
if getattr(self.text_encoder, "peft_config", None) is not None:
|
||||
del self.text_encoder.peft_config
|
||||
self.text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
recurse_remove_peft_layers(self.text_encoder_2)
|
||||
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
else:
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
|
||||
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
from torch import nn
|
||||
|
||||
|
||||
def get_activation(act_fn):
|
||||
def get_activation(act_fn: str) -> nn.Module:
|
||||
"""Helper function to get activation function from string.
|
||||
|
||||
Args:
|
||||
act_fn (str): Name of activation function.
|
||||
|
||||
Returns:
|
||||
nn.Module: Activation function.
|
||||
"""
|
||||
if act_fn in ["swish", "silu"]:
|
||||
return nn.SiLU()
|
||||
elif act_fn == "mish":
|
||||
|
||||
+110
-10
@@ -231,7 +231,11 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
The number of channel of each downsample block's output hidden state. The `len(block_out_channels)` will
|
||||
also determine the number of downsample blocks in the Adapter.
|
||||
num_res_blocks (`int`, *optional*, defaults to 2):
|
||||
Number of ResNet blocks in each downsample block
|
||||
Number of ResNet blocks in each downsample block.
|
||||
downscale_factor (`int`, *optional*, defaults to 8):
|
||||
A factor that determines the total downscale factor of the Adapter.
|
||||
adapter_type (`str`, *optional*, defaults to `full_adapter`):
|
||||
The type of Adapter to use. Choose either `full_adapter` or `full_adapter_xl` or `light_adapter`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -258,6 +262,12 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
r"""
|
||||
This function processes the input tensor `x` through the adapter model and returns a list of feature tensors,
|
||||
each representing information extracted at a different scale from the input. The length of the list is
|
||||
determined by the number of downsample blocks in the Adapter, as specified by the `channels` and
|
||||
`num_res_blocks` parameters during initialization.
|
||||
"""
|
||||
return self.adapter(x)
|
||||
|
||||
@property
|
||||
@@ -269,6 +279,10 @@ class T2IAdapter(ModelMixin, ConfigMixin):
|
||||
|
||||
|
||||
class FullAdapter(nn.Module):
|
||||
r"""
|
||||
See [`T2IAdapter`] for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
@@ -296,6 +310,12 @@ class FullAdapter(nn.Module):
|
||||
self.total_downscale_factor = downscale_factor * 2 ** (len(channels) - 1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
r"""
|
||||
This method processes the input tensor `x` through the FullAdapter model and performs operations including
|
||||
pixel unshuffling, convolution, and a stack of AdapterBlocks. It returns a list of feature tensors, each
|
||||
capturing information at a different stage of processing within the FullAdapter model. The number of feature
|
||||
tensors in the list is determined by the number of downsample blocks specified during initialization.
|
||||
"""
|
||||
x = self.unshuffle(x)
|
||||
x = self.conv_in(x)
|
||||
|
||||
@@ -309,6 +329,10 @@ class FullAdapter(nn.Module):
|
||||
|
||||
|
||||
class FullAdapterXL(nn.Module):
|
||||
r"""
|
||||
See [`T2IAdapter`] for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
@@ -338,6 +362,10 @@ class FullAdapterXL(nn.Module):
|
||||
self.total_downscale_factor = downscale_factor * 2
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
r"""
|
||||
This method takes the tensor x as input and processes it through FullAdapterXL model. It consists of operations
|
||||
including unshuffling pixels, applying convolution layer and appending each block into list of feature tensors.
|
||||
"""
|
||||
x = self.unshuffle(x)
|
||||
x = self.conv_in(x)
|
||||
|
||||
@@ -351,7 +379,22 @@ class FullAdapterXL(nn.Module):
|
||||
|
||||
|
||||
class AdapterBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
|
||||
r"""
|
||||
An AdapterBlock is a helper model that contains multiple ResNet-like blocks. It is used in the `FullAdapter` and
|
||||
`FullAdapterXL` models.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`):
|
||||
Number of channels of AdapterBlock's input.
|
||||
out_channels (`int`):
|
||||
Number of channels of AdapterBlock's output.
|
||||
num_res_blocks (`int`):
|
||||
Number of ResNet blocks in the AdapterBlock.
|
||||
down (`bool`, *optional*, defaults to `False`):
|
||||
Whether to perform downsampling on AdapterBlock's input.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.downsample = None
|
||||
@@ -366,7 +409,12 @@ class AdapterBlock(nn.Module):
|
||||
*[AdapterResnetBlock(out_channels) for _ in range(num_res_blocks)],
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
This method takes tensor x as input and performs operations downsampling and convolutional layers if the
|
||||
self.downsample and self.in_conv properties of AdapterBlock model are specified. Then it applies a series of
|
||||
residual blocks to the input tensor.
|
||||
"""
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
@@ -379,13 +427,25 @@ class AdapterBlock(nn.Module):
|
||||
|
||||
|
||||
class AdapterResnetBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
r"""
|
||||
An `AdapterResnetBlock` is a helper model that implements a ResNet-like block.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
Number of channels of AdapterResnetBlock's input and output.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
||||
self.act = nn.ReLU()
|
||||
self.block2 = nn.Conv2d(channels, channels, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
This method takes input tensor x and applies a convolutional layer, ReLU activation, and another convolutional
|
||||
layer on the input tensor. It returns addition with the input tensor.
|
||||
"""
|
||||
h = x
|
||||
h = self.block1(h)
|
||||
h = self.act(h)
|
||||
@@ -398,6 +458,10 @@ class AdapterResnetBlock(nn.Module):
|
||||
|
||||
|
||||
class LightAdapter(nn.Module):
|
||||
r"""
|
||||
See [`T2IAdapter`] for more information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
@@ -424,7 +488,11 @@ class LightAdapter(nn.Module):
|
||||
|
||||
self.total_downscale_factor = downscale_factor * (2 ** len(channels))
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
r"""
|
||||
This method takes the input tensor x and performs downscaling and appends it in list of feature tensors. Each
|
||||
feature tensor corresponds to a different level of processing within the LightAdapter.
|
||||
"""
|
||||
x = self.unshuffle(x)
|
||||
|
||||
features = []
|
||||
@@ -437,7 +505,22 @@ class LightAdapter(nn.Module):
|
||||
|
||||
|
||||
class LightAdapterBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_res_blocks, down=False):
|
||||
r"""
|
||||
A `LightAdapterBlock` is a helper model that contains multiple `LightAdapterResnetBlocks`. It is used in the
|
||||
`LightAdapter` model.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`):
|
||||
Number of channels of LightAdapterBlock's input.
|
||||
out_channels (`int`):
|
||||
Number of channels of LightAdapterBlock's output.
|
||||
num_res_blocks (`int`):
|
||||
Number of LightAdapterResnetBlocks in the LightAdapterBlock.
|
||||
down (`bool`, *optional*, defaults to `False`):
|
||||
Whether to perform downsampling on LightAdapterBlock's input.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, num_res_blocks: int, down: bool = False):
|
||||
super().__init__()
|
||||
mid_channels = out_channels // 4
|
||||
|
||||
@@ -449,7 +532,11 @@ class LightAdapterBlock(nn.Module):
|
||||
self.resnets = nn.Sequential(*[LightAdapterResnetBlock(mid_channels) for _ in range(num_res_blocks)])
|
||||
self.out_conv = nn.Conv2d(mid_channels, out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
This method takes tensor x as input and performs downsampling if required. Then it applies in convolution
|
||||
layer, a sequence of residual blocks, and out convolutional layer.
|
||||
"""
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
@@ -461,13 +548,26 @@ class LightAdapterBlock(nn.Module):
|
||||
|
||||
|
||||
class LightAdapterResnetBlock(nn.Module):
|
||||
def __init__(self, channels):
|
||||
"""
|
||||
A `LightAdapterResnetBlock` is a helper model that implements a ResNet-like block with a slightly different
|
||||
architecture than `AdapterResnetBlock`.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
Number of channels of LightAdapterResnetBlock's input and output.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.block1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
||||
self.act = nn.ReLU()
|
||||
self.block2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
This function takes input tensor x and processes it through one convolutional layer, ReLU activation, and
|
||||
another convolutional layer and adds it to input tensor.
|
||||
"""
|
||||
h = x
|
||||
h = self.block1(h)
|
||||
h = self.act(h)
|
||||
|
||||
@@ -11,12 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import get_activation
|
||||
from .attention_processor import Attention
|
||||
@@ -26,7 +27,17 @@ from .lora import LoRACompatibleLinear
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class GatedSelfAttentionDense(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, n_heads, d_head):
|
||||
r"""
|
||||
A gated self-attention dense layer that combines visual features and object features.
|
||||
|
||||
Parameters:
|
||||
query_dim (`int`): The number of channels in the query.
|
||||
context_dim (`int`): The number of channels in the context.
|
||||
n_heads (`int`): The number of heads to use for attention.
|
||||
d_head (`int`): The number of channels in each head.
|
||||
"""
|
||||
|
||||
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
|
||||
super().__init__()
|
||||
|
||||
# we need a linear projection since we need cat visual feature and obj feature
|
||||
@@ -43,7 +54,7 @@ class GatedSelfAttentionDense(nn.Module):
|
||||
|
||||
self.enabled = True
|
||||
|
||||
def forward(self, x, objs):
|
||||
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
|
||||
if not self.enabled:
|
||||
return x
|
||||
|
||||
@@ -67,15 +78,25 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -175,7 +196,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
) -> torch.FloatTensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
# 0. Self-Attention
|
||||
if self.use_ada_layer_norm:
|
||||
@@ -186,6 +207,10 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
# print("After first norm")
|
||||
# print(f"hidden_states: {hidden_states.dtype}")
|
||||
# print(f"norm_hidden_states: {norm_hidden_states.dtype}")
|
||||
# print(f"encoder_hidden_states: {norm_hidden_states.dtype}")
|
||||
|
||||
# 1. Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
@@ -202,7 +227,9 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
# print(f"attn_output: {attn_output.dtype}")
|
||||
hidden_states = attn_output + hidden_states
|
||||
# print(f"attn_output: {attn_output.dtype}")
|
||||
|
||||
# 2.5 GLIGEN Control
|
||||
if gligen_kwargs is not None:
|
||||
@@ -280,6 +307,7 @@ class FeedForward(nn.Module):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim)
|
||||
@@ -296,14 +324,15 @@ class FeedForward(nn.Module):
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
|
||||
self.net.append(linear_cls(inner_dim, dim_out))
|
||||
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
def forward(self, hidden_states, scale: float = 1.0):
|
||||
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
|
||||
for module in self.net:
|
||||
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
|
||||
if isinstance(module, compatible_cls):
|
||||
hidden_states = module(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = module(hidden_states)
|
||||
@@ -313,6 +342,11 @@ class FeedForward(nn.Module):
|
||||
class GELU(nn.Module):
|
||||
r"""
|
||||
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
||||
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
||||
@@ -320,7 +354,7 @@ class GELU(nn.Module):
|
||||
self.proj = nn.Linear(dim_in, dim_out)
|
||||
self.approximate = approximate
|
||||
|
||||
def gelu(self, gate):
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
if gate.device.type != "mps":
|
||||
return F.gelu(gate, approximate=self.approximate)
|
||||
# mps: gelu is not implemented for float16
|
||||
@@ -343,48 +377,58 @@ class GEGLU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
super().__init__()
|
||||
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
|
||||
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
|
||||
|
||||
def gelu(self, gate):
|
||||
self.proj = linear_cls(dim_in, dim_out * 2)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
if gate.device.type != "mps":
|
||||
return F.gelu(gate)
|
||||
# mps: gelu is not implemented for float16
|
||||
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
||||
|
||||
def forward(self, hidden_states, scale: float = 1.0):
|
||||
hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
|
||||
return hidden_states * self.gelu(gate)
|
||||
|
||||
|
||||
class ApproximateGELU(nn.Module):
|
||||
"""
|
||||
The approximate form of Gaussian Error Linear Unit (GELU)
|
||||
r"""
|
||||
The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2:
|
||||
https://arxiv.org/abs/1606.08415.
|
||||
|
||||
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
"""
|
||||
r"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the dictionary of embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, num_embeddings):
|
||||
def __init__(self, embedding_dim: int, num_embeddings: int):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, timestep):
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(self.silu(self.emb(timestep)))
|
||||
scale, shift = torch.chunk(emb, 2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
@@ -392,11 +436,15 @@ class AdaLayerNorm(nn.Module):
|
||||
|
||||
|
||||
class AdaLayerNormZero(nn.Module):
|
||||
"""
|
||||
r"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the dictionary of embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, num_embeddings):
|
||||
def __init__(self, embedding_dim: int, num_embeddings: int):
|
||||
super().__init__()
|
||||
|
||||
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
||||
@@ -405,7 +453,13 @@ class AdaLayerNormZero(nn.Module):
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, timestep, class_labels, hidden_dtype=None):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
class_labels: torch.LongTensor,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
@@ -413,8 +467,15 @@ class AdaLayerNormZero(nn.Module):
|
||||
|
||||
|
||||
class AdaGroupNorm(nn.Module):
|
||||
"""
|
||||
r"""
|
||||
GroupNorm layer modified to incorporate timestep embeddings.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
num_embeddings (`int`): The size of the dictionary of embeddings.
|
||||
num_groups (`int`): The number of groups to separate the channels into.
|
||||
act_fn (`str`, *optional*, defaults to `None`): The activation function to use.
|
||||
eps (`float`, *optional*, defaults to `1e-5`): The epsilon value to use for numerical stability.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -431,7 +492,7 @@ class AdaGroupNorm(nn.Module):
|
||||
|
||||
self.linear = nn.Linear(embedding_dim, out_dim * 2)
|
||||
|
||||
def forward(self, x, emb):
|
||||
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
|
||||
if self.act:
|
||||
emb = self.act(emb)
|
||||
emb = self.linear(emb)
|
||||
|
||||
@@ -131,6 +131,9 @@ class FlaxAttention(nn.Module):
|
||||
Dropout rate
|
||||
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
||||
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
||||
split_head_dim (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
||||
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
|
||||
@@ -140,6 +143,7 @@ class FlaxAttention(nn.Module):
|
||||
dim_head: int = 64
|
||||
dropout: float = 0.0
|
||||
use_memory_efficient_attention: bool = False
|
||||
split_head_dim: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
|
||||
def setup(self):
|
||||
@@ -177,9 +181,15 @@ class FlaxAttention(nn.Module):
|
||||
key_proj = self.key(context)
|
||||
value_proj = self.value(context)
|
||||
|
||||
query_states = self.reshape_heads_to_batch_dim(query_proj)
|
||||
key_states = self.reshape_heads_to_batch_dim(key_proj)
|
||||
value_states = self.reshape_heads_to_batch_dim(value_proj)
|
||||
if self.split_head_dim:
|
||||
b = hidden_states.shape[0]
|
||||
query_states = jnp.reshape(query_proj, (b, -1, self.heads, self.dim_head))
|
||||
key_states = jnp.reshape(key_proj, (b, -1, self.heads, self.dim_head))
|
||||
value_states = jnp.reshape(value_proj, (b, -1, self.heads, self.dim_head))
|
||||
else:
|
||||
query_states = self.reshape_heads_to_batch_dim(query_proj)
|
||||
key_states = self.reshape_heads_to_batch_dim(key_proj)
|
||||
value_states = self.reshape_heads_to_batch_dim(value_proj)
|
||||
|
||||
if self.use_memory_efficient_attention:
|
||||
query_states = query_states.transpose(1, 0, 2)
|
||||
@@ -206,14 +216,23 @@ class FlaxAttention(nn.Module):
|
||||
hidden_states = hidden_states.transpose(1, 0, 2)
|
||||
else:
|
||||
# compute attentions
|
||||
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
|
||||
if self.split_head_dim:
|
||||
attention_scores = jnp.einsum("b t n h, b f n h -> b n f t", key_states, query_states)
|
||||
else:
|
||||
attention_scores = jnp.einsum("b i d, b j d->b i j", query_states, key_states)
|
||||
|
||||
attention_scores = attention_scores * self.scale
|
||||
attention_probs = nn.softmax(attention_scores, axis=2)
|
||||
attention_probs = nn.softmax(attention_scores, axis=-1 if self.split_head_dim else 2)
|
||||
|
||||
# attend to values
|
||||
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
|
||||
if self.split_head_dim:
|
||||
hidden_states = jnp.einsum("b n f t, b t n h -> b f n h", attention_probs, value_states)
|
||||
b = hidden_states.shape[0]
|
||||
hidden_states = jnp.reshape(hidden_states, (b, -1, self.heads * self.dim_head))
|
||||
else:
|
||||
hidden_states = jnp.einsum("b i j, b j d -> b i d", attention_probs, value_states)
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
return self.dropout_layer(hidden_states, deterministic=deterministic)
|
||||
|
||||
@@ -239,6 +258,9 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
Parameters `dtype`
|
||||
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
||||
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
||||
split_head_dim (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
||||
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
||||
"""
|
||||
dim: int
|
||||
n_heads: int
|
||||
@@ -247,15 +269,28 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
only_cross_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
use_memory_efficient_attention: bool = False
|
||||
split_head_dim: bool = False
|
||||
|
||||
def setup(self):
|
||||
# self attention (or cross_attention if only_cross_attention is True)
|
||||
self.attn1 = FlaxAttention(
|
||||
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
|
||||
self.dim,
|
||||
self.n_heads,
|
||||
self.d_head,
|
||||
self.dropout,
|
||||
self.use_memory_efficient_attention,
|
||||
self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
# cross attention
|
||||
self.attn2 = FlaxAttention(
|
||||
self.dim, self.n_heads, self.d_head, self.dropout, self.use_memory_efficient_attention, dtype=self.dtype
|
||||
self.dim,
|
||||
self.n_heads,
|
||||
self.d_head,
|
||||
self.dropout,
|
||||
self.use_memory_efficient_attention,
|
||||
self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.ff = FlaxFeedForward(dim=self.dim, dropout=self.dropout, dtype=self.dtype)
|
||||
self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)
|
||||
@@ -308,6 +343,9 @@ class FlaxTransformer2DModel(nn.Module):
|
||||
Parameters `dtype`
|
||||
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
||||
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
||||
split_head_dim (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
||||
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
||||
"""
|
||||
in_channels: int
|
||||
n_heads: int
|
||||
@@ -318,6 +356,7 @@ class FlaxTransformer2DModel(nn.Module):
|
||||
only_cross_attention: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
use_memory_efficient_attention: bool = False
|
||||
split_head_dim: bool = False
|
||||
|
||||
def setup(self):
|
||||
self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)
|
||||
@@ -343,6 +382,7 @@ class FlaxTransformer2DModel(nn.Module):
|
||||
only_cross_attention=self.only_cross_attention,
|
||||
dtype=self.dtype,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
]
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils import USE_PEFT_BACKEND, deprecate, logging
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .lora import LoRACompatibleLinear, LoRALinearLayer
|
||||
@@ -137,22 +137,27 @@ class Attention(nn.Module):
|
||||
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
|
||||
)
|
||||
|
||||
self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
|
||||
if USE_PEFT_BACKEND:
|
||||
linear_cls = nn.Linear
|
||||
else:
|
||||
linear_cls = LoRACompatibleLinear
|
||||
|
||||
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
|
||||
self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
|
||||
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
||||
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
|
||||
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
# set attention processor
|
||||
@@ -310,19 +315,16 @@ class Attention(nn.Module):
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
def set_processor(self, processor: "AttnProcessor"):
|
||||
if (
|
||||
hasattr(self, "processor")
|
||||
and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
|
||||
and self.to_q.lora_layer is not None
|
||||
):
|
||||
def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
|
||||
if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
|
||||
deprecate(
|
||||
"set_processor to offload LoRA",
|
||||
"0.26.0",
|
||||
"In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
|
||||
"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
|
||||
)
|
||||
# TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
|
||||
# We need to remove all LoRA layers
|
||||
# Don't forget to remove ALL `_remove_lora` from the codebase
|
||||
for module in self.modules():
|
||||
if hasattr(module, "set_lora_layer"):
|
||||
module.set_lora_layer(None)
|
||||
@@ -548,6 +550,8 @@ class AttnProcessor:
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -565,15 +569,15 @@ class AttnProcessor:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, scale=scale)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, scale=scale)
|
||||
value = attn.to_v(encoder_hidden_states, scale=scale)
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
@@ -584,7 +588,7 @@ class AttnProcessor:
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1010,15 +1014,20 @@ class AttnProcessor2_0:
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states, scale=scale)
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, scale=scale)
|
||||
value = attn.to_v(encoder_hidden_states, scale=scale)
|
||||
key = (
|
||||
attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
|
||||
)
|
||||
value = (
|
||||
attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
|
||||
)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
@@ -1038,7 +1047,9 @@ class AttnProcessor2_0:
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, scale=scale)
|
||||
hidden_states = (
|
||||
attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
|
||||
)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -1650,7 +1661,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
"0.26.0",
|
||||
(
|
||||
f"Make sure use {self_cls_name[4:]} instead by setting"
|
||||
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
||||
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
||||
" `LoraLoaderMixin.load_lora_weights`"
|
||||
),
|
||||
)
|
||||
@@ -1700,7 +1711,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
|
||||
"0.26.0",
|
||||
(
|
||||
f"Make sure use {self_cls_name[4:]} instead by setting"
|
||||
"LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
||||
"LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
|
||||
" `LoraLoaderMixin.load_lora_weights`"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -196,7 +196,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -220,9 +222,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -244,10 +246,24 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
def encode(
|
||||
self, x: torch.FloatTensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.FloatTensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
||||
return self.tiled_encode(x, return_dict=return_dict)
|
||||
|
||||
@@ -279,6 +295,20 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.FloatTensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
|
||||
@@ -517,7 +517,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -541,9 +543,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -565,7 +567,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size):
|
||||
@@ -669,7 +671,13 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
|
||||
@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
controlnet_conditioning_channel_order: str = "rgb"
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
|
||||
@@ -107,14 +107,18 @@ class DualTransformer2DModel(nn.Module):
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Optional attention mask to be applied in Attention
|
||||
Optional attention mask to be applied in Attention.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .activations import get_activation
|
||||
from .lora import LoRACompatibleLinear
|
||||
|
||||
@@ -166,8 +167,9 @@ class TimestepEmbedding(nn.Module):
|
||||
cond_proj_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
|
||||
self.linear_1 = linear_cls(in_channels, time_embed_dim)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
|
||||
@@ -180,7 +182,7 @@ class TimestepEmbedding(nn.Module):
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
|
||||
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
|
||||
+101
-29
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -25,29 +25,50 @@ from ..utils import logging
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
|
||||
if use_peft_backend:
|
||||
from peft.tuners.lora import LoraLayer
|
||||
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_scale = lora_scale
|
||||
attn_module.k_proj.lora_scale = lora_scale
|
||||
attn_module.v_proj.lora_scale = lora_scale
|
||||
attn_module.out_proj.lora_scale = lora_scale
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, LoraLayer):
|
||||
module.scaling[module.active_adapter] = lora_scale
|
||||
else:
|
||||
for _, attn_module in text_encoder_attn_modules(text_encoder):
|
||||
if isinstance(attn_module.q_proj, PatchedLoraProjection):
|
||||
attn_module.q_proj.lora_scale = lora_scale
|
||||
attn_module.k_proj.lora_scale = lora_scale
|
||||
attn_module.v_proj.lora_scale = lora_scale
|
||||
attn_module.out_proj.lora_scale = lora_scale
|
||||
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_scale = lora_scale
|
||||
mlp_module.fc2.lora_scale = lora_scale
|
||||
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
|
||||
if isinstance(mlp_module.fc1, PatchedLoraProjection):
|
||||
mlp_module.fc1.lora_scale = lora_scale
|
||||
mlp_module.fc2.lora_scale = lora_scale
|
||||
|
||||
|
||||
class LoRALinearLayer(nn.Module):
|
||||
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
|
||||
r"""
|
||||
A linear layer that is used with LoRA.
|
||||
|
||||
Parameters:
|
||||
in_features (`int`):
|
||||
Number of input features.
|
||||
out_features (`int`):
|
||||
Number of output features.
|
||||
rank (`int`, `optional`, defaults to 4):
|
||||
The rank of the LoRA layer.
|
||||
network_alpha (`float`, `optional`, defaults to `None`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the same
|
||||
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
|
||||
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
device (`torch.device`, `optional`, defaults to `None`):
|
||||
The device to use for the layer's weights.
|
||||
dtype (`torch.dtype`, `optional`, defaults to `None`):
|
||||
The dtype to use for the layer's weights.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
rank: int = 4,
|
||||
network_alpha: Optional[float] = None,
|
||||
device: Optional[Union[torch.device, str]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
|
||||
@@ -62,7 +83,8 @@ class LoRALinearLayer(nn.Module):
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# print(f"From {self.__class__.__name__}: hidden_states: {hidden_states.dtype}")
|
||||
orig_dtype = hidden_states.dtype
|
||||
dtype = self.down.weight.dtype
|
||||
|
||||
@@ -72,12 +94,43 @@ class LoRALinearLayer(nn.Module):
|
||||
if self.network_alpha is not None:
|
||||
up_hidden_states *= self.network_alpha / self.rank
|
||||
|
||||
return up_hidden_states.to(orig_dtype)
|
||||
out = up_hidden_states.to(orig_dtype)
|
||||
# print(f"From {self.__class__.__name__}: out: {out.dtype}")
|
||||
return out
|
||||
|
||||
|
||||
class LoRAConv2dLayer(nn.Module):
|
||||
r"""
|
||||
A convolutional layer that is used with LoRA.
|
||||
|
||||
Parameters:
|
||||
in_features (`int`):
|
||||
Number of input features.
|
||||
out_features (`int`):
|
||||
Number of output features.
|
||||
rank (`int`, `optional`, defaults to 4):
|
||||
The rank of the LoRA layer.
|
||||
kernel_size (`int` or `tuple` of two `int`, `optional`, defaults to 1):
|
||||
The kernel size of the convolution.
|
||||
stride (`int` or `tuple` of two `int`, `optional`, defaults to 1):
|
||||
The stride of the convolution.
|
||||
padding (`int` or `tuple` of two `int` or `str`, `optional`, defaults to 0):
|
||||
The padding of the convolution.
|
||||
network_alpha (`float`, `optional`, defaults to `None`):
|
||||
The value of the network alpha used for stable learning and preventing underflow. This value has the same
|
||||
meaning as the `--network_alpha` option in the kohya-ss trainer script. See
|
||||
https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
rank: int = 4,
|
||||
kernel_size: Union[int, Tuple[int, int]] = (1, 1),
|
||||
stride: Union[int, Tuple[int, int]] = (1, 1),
|
||||
padding: Union[int, Tuple[int, int], str] = 0,
|
||||
network_alpha: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -94,7 +147,7 @@ class LoRAConv2dLayer(nn.Module):
|
||||
nn.init.normal_(self.down.weight, std=1 / rank)
|
||||
nn.init.zeros_(self.up.weight)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = hidden_states.dtype
|
||||
dtype = self.down.weight.dtype
|
||||
|
||||
@@ -119,7 +172,7 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale=1.0):
|
||||
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
||||
if self.lora_layer is None:
|
||||
return
|
||||
|
||||
@@ -135,6 +188,14 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
|
||||
fusion = fusion.reshape((w_orig.shape))
|
||||
fused_weight = w_orig + (lora_scale * fusion)
|
||||
|
||||
if safe_fusing and torch.isnan(fused_weight).any().item():
|
||||
raise ValueError(
|
||||
"This LoRA weight seems to be broken. "
|
||||
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
|
||||
"LoRA weights will not be fused."
|
||||
)
|
||||
|
||||
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
# we can drop the lora layer now
|
||||
@@ -163,7 +224,7 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
self.w_up = None
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, hidden_states, scale: float = 1.0):
|
||||
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
if self.lora_layer is None:
|
||||
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
|
||||
# see: https://github.com/huggingface/diffusers/pull/4315
|
||||
@@ -171,7 +232,10 @@ class LoRACompatibleConv(nn.Conv2d):
|
||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
)
|
||||
else:
|
||||
return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
|
||||
original_outputs = F.conv2d(
|
||||
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
)
|
||||
return original_outputs + (scale * self.lora_layer(hidden_states))
|
||||
|
||||
|
||||
class LoRACompatibleLinear(nn.Linear):
|
||||
@@ -186,7 +250,7 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
|
||||
self.lora_layer = lora_layer
|
||||
|
||||
def _fuse_lora(self, lora_scale=1.0):
|
||||
def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
|
||||
if self.lora_layer is None:
|
||||
return
|
||||
|
||||
@@ -200,6 +264,14 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
|
||||
|
||||
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
|
||||
|
||||
if safe_fusing and torch.isnan(fused_weight).any().item():
|
||||
raise ValueError(
|
||||
"This LoRA weight seems to be broken. "
|
||||
f"Encountered NaN values when trying to fuse LoRA weights for {self}."
|
||||
"LoRA weights will not be fused."
|
||||
)
|
||||
|
||||
self.weight.data = fused_weight.to(device=device, dtype=dtype)
|
||||
|
||||
# we can drop the lora layer now
|
||||
@@ -226,7 +298,7 @@ class LoRACompatibleLinear(nn.Linear):
|
||||
self.w_up = None
|
||||
self.w_down = None
|
||||
|
||||
def forward(self, hidden_states, scale: float = 1.0):
|
||||
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
||||
if self.lora_layer is None:
|
||||
out = super().forward(hidden_states)
|
||||
return out
|
||||
|
||||
@@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
```"""
|
||||
return self._cast_floating_to(params, jnp.float16, mask)
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
|
||||
def init_weights(self, rng: jax.Array) -> Dict:
|
||||
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -32,10 +32,12 @@ from ..utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
HF_HUB_OFFLINE,
|
||||
MIN_PEFT_VERSION,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
_add_variant,
|
||||
_get_model_file,
|
||||
check_peft_version,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
@@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
|
||||
_supports_gradient_checkpointing = False
|
||||
_keys_to_ignore_on_load_unexpected = None
|
||||
_hf_peft_config_loaded = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -292,6 +295,153 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
|
||||
r"""
|
||||
Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
|
||||
to the adapter to follow the convention of the PEFT library.
|
||||
|
||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
|
||||
[documentation](https://huggingface.co/docs/peft).
|
||||
|
||||
Args:
|
||||
adapter_config (`[~peft.PeftConfig]`):
|
||||
The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
|
||||
methods.
|
||||
adapter_name (`str`, *optional*, defaults to `"default"`):
|
||||
The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
|
||||
from peft import PeftConfig, inject_adapter_in_model
|
||||
|
||||
if not self._hf_peft_config_loaded:
|
||||
self._hf_peft_config_loaded = True
|
||||
elif adapter_name in self.peft_config:
|
||||
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
|
||||
|
||||
if not isinstance(adapter_config, PeftConfig):
|
||||
raise ValueError(
|
||||
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
|
||||
)
|
||||
|
||||
# Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
|
||||
# handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
|
||||
adapter_config.base_model_name_or_path = None
|
||||
inject_adapter_in_model(adapter_config, self, adapter_name)
|
||||
self.set_adapter(adapter_name)
|
||||
|
||||
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
|
||||
"""
|
||||
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
|
||||
|
||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
||||
official documentation: https://huggingface.co/docs/peft
|
||||
|
||||
Args:
|
||||
adapter_name (Union[str, List[str]])):
|
||||
The list of adapters to set or the adapter name in case of single adapter.
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
|
||||
if not self._hf_peft_config_loaded:
|
||||
raise ValueError("No adapter loaded. Please load an adapter first.")
|
||||
|
||||
if isinstance(adapter_name, str):
|
||||
adapter_name = [adapter_name]
|
||||
|
||||
missing = set(adapter_name) - set(self.peft_config)
|
||||
if len(missing) > 0:
|
||||
raise ValueError(
|
||||
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
|
||||
f" current loaded adapters are: {list(self.peft_config.keys())}"
|
||||
)
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
_adapters_has_been_set = False
|
||||
|
||||
for _, module in self.named_modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if hasattr(module, "set_adapter"):
|
||||
module.set_adapter(adapter_name)
|
||||
# Previous versions of PEFT does not support multi-adapter inference
|
||||
elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
|
||||
raise ValueError(
|
||||
"You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
|
||||
" `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
|
||||
)
|
||||
else:
|
||||
module.active_adapter = adapter_name
|
||||
_adapters_has_been_set = True
|
||||
|
||||
if not _adapters_has_been_set:
|
||||
raise ValueError(
|
||||
"Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
|
||||
)
|
||||
|
||||
def disable_adapters(self) -> None:
|
||||
r"""
|
||||
Disable all adapters attached to the model and fallback to inference with the base model only.
|
||||
|
||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
||||
official documentation: https://huggingface.co/docs/peft
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
|
||||
if not self._hf_peft_config_loaded:
|
||||
raise ValueError("No adapter loaded. Please load an adapter first.")
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for _, module in self.named_modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if hasattr(module, "enable_adapters"):
|
||||
module.enable_adapters(enabled=False)
|
||||
else:
|
||||
# support for older PEFT versions
|
||||
module.disable_adapters = True
|
||||
|
||||
def enable_adapters(self) -> None:
|
||||
"""
|
||||
Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
|
||||
list of adapters to enable.
|
||||
|
||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
||||
official documentation: https://huggingface.co/docs/peft
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
|
||||
if not self._hf_peft_config_loaded:
|
||||
raise ValueError("No adapter loaded. Please load an adapter first.")
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for _, module in self.named_modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if hasattr(module, "enable_adapters"):
|
||||
module.enable_adapters(enabled=True)
|
||||
else:
|
||||
# support for older PEFT versions
|
||||
module.disable_adapters = False
|
||||
|
||||
def active_adapters(self) -> List[str]:
|
||||
"""
|
||||
Gets the current list of active adapters of the model.
|
||||
|
||||
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
|
||||
official documentation: https://huggingface.co/docs/peft
|
||||
"""
|
||||
check_peft_version(min_version=MIN_PEFT_VERSION)
|
||||
|
||||
if not self._hf_peft_config_loaded:
|
||||
raise ValueError("No adapter loaded. Please load an adapter first.")
|
||||
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for _, module in self.named_modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return module.active_adapter
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
|
||||
@@ -192,7 +192,9 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -216,9 +218,9 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -240,7 +242,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
+182
-60
@@ -14,12 +14,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .activations import get_activation
|
||||
from .attention import AdaGroupNorm
|
||||
from .attention_processor import SpatialNorm
|
||||
@@ -38,9 +39,18 @@ class Upsample1D(nn.Module):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@@ -54,7 +64,7 @@ class Upsample1D(nn.Module):
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(inputs)
|
||||
@@ -79,9 +89,18 @@ class Downsample1D(nn.Module):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@@ -96,7 +115,7 @@ class Downsample1D(nn.Module):
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
return self.conv(inputs)
|
||||
|
||||
@@ -113,21 +132,31 @@ class Upsample2D(nn.Module):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
@@ -135,7 +164,7 @@ class Upsample2D(nn.Module):
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(self, hidden_states, output_size=None, scale: float = 1.0):
|
||||
def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv_transpose:
|
||||
@@ -166,12 +195,12 @@ class Upsample2D(nn.Module):
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
if isinstance(self.Conv2d_0, LoRACompatibleConv):
|
||||
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.Conv2d_0(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
@@ -191,9 +220,18 @@ class Downsample2D(nn.Module):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
@@ -201,9 +239,10 @@ class Downsample2D(nn.Module):
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if use_conv:
|
||||
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
@@ -219,13 +258,18 @@ class Downsample2D(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, scale: float = 1.0):
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
@@ -246,7 +290,13 @@ class FirUpsample2D(nn.Module):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
@@ -255,7 +305,14 @@ class FirUpsample2D(nn.Module):
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
||||
def _upsample_2d(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
@@ -335,7 +392,7 @@ class FirUpsample2D(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_conv:
|
||||
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
@@ -359,7 +416,13 @@ class FirDownsample2D(nn.Module):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
@@ -368,7 +431,14 @@ class FirDownsample2D(nn.Module):
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
|
||||
def _downsample_2d(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
@@ -422,7 +492,7 @@ class FirDownsample2D(nn.Module):
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_conv:
|
||||
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
@@ -434,14 +504,20 @@ class FirDownsample2D(nn.Module):
|
||||
|
||||
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
||||
class KDownsample2D(nn.Module):
|
||||
def __init__(self, pad_mode="reflect"):
|
||||
r"""A 2D K-downsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
@@ -451,14 +527,20 @@ class KDownsample2D(nn.Module):
|
||||
|
||||
|
||||
class KUpsample2D(nn.Module):
|
||||
def __init__(self, pad_mode="reflect"):
|
||||
r"""A 2D K-upsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
@@ -501,23 +583,23 @@ class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout=0.0,
|
||||
temb_channels=512,
|
||||
groups=32,
|
||||
groups_out=None,
|
||||
pre_norm=True,
|
||||
eps=1e-6,
|
||||
non_linearity="swish",
|
||||
skip_time_act=False,
|
||||
time_embedding_norm="default", # default, scale_shift, ada_group, spatial
|
||||
kernel=None,
|
||||
output_scale_factor=1.0,
|
||||
use_in_shortcut=None,
|
||||
up=False,
|
||||
down=False,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
groups: int = 32,
|
||||
groups_out: Optional[int] = None,
|
||||
pre_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
non_linearity: str = "swish",
|
||||
skip_time_act: bool = False,
|
||||
time_embedding_norm: str = "default", # default, scale_shift, ada_group, spatial
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
output_scale_factor: float = 1.0,
|
||||
use_in_shortcut: Optional[bool] = None,
|
||||
up: bool = False,
|
||||
down: bool = False,
|
||||
conv_shortcut_bias: bool = True,
|
||||
conv_2d_out_channels: Optional[int] = None,
|
||||
):
|
||||
@@ -534,6 +616,9 @@ class ResnetBlock2D(nn.Module):
|
||||
self.time_embedding_norm = time_embedding_norm
|
||||
self.skip_time_act = skip_time_act
|
||||
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
@@ -544,13 +629,13 @@ class ResnetBlock2D(nn.Module):
|
||||
else:
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
if self.time_embedding_norm == "default":
|
||||
self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
|
||||
self.time_emb_proj = linear_cls(temb_channels, out_channels)
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
|
||||
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
|
||||
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
|
||||
self.time_emb_proj = None
|
||||
else:
|
||||
@@ -567,7 +652,7 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
conv_2d_out_channels = conv_2d_out_channels or out_channels
|
||||
self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.nonlinearity = get_activation(non_linearity)
|
||||
|
||||
@@ -593,7 +678,7 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
self.conv_shortcut = None
|
||||
if self.use_in_shortcut:
|
||||
self.conv_shortcut = LoRACompatibleConv(
|
||||
self.conv_shortcut = conv_cls(
|
||||
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
|
||||
)
|
||||
|
||||
@@ -634,12 +719,16 @@ class ResnetBlock2D(nn.Module):
|
||||
else self.downsample(hidden_states)
|
||||
)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, scale)
|
||||
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
|
||||
|
||||
if self.time_emb_proj is not None:
|
||||
if not self.skip_time_act:
|
||||
temb = self.nonlinearity(temb)
|
||||
temb = self.time_emb_proj(temb, scale)[:, :, None, None]
|
||||
temb = (
|
||||
self.time_emb_proj(temb, scale)[:, :, None, None]
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.time_emb_proj(temb)[:, :, None, None]
|
||||
)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
@@ -656,10 +745,12 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, scale)
|
||||
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor, scale)
|
||||
input_tensor = (
|
||||
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
|
||||
)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
@@ -667,7 +758,7 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
|
||||
# unet_rl.py
|
||||
def rearrange_dims(tensor):
|
||||
def rearrange_dims(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if len(tensor.shape) == 2:
|
||||
return tensor[:, :, None]
|
||||
if len(tensor.shape) == 3:
|
||||
@@ -681,16 +772,24 @@ def rearrange_dims(tensor):
|
||||
class Conv1dBlock(nn.Module):
|
||||
"""
|
||||
Conv1d --> GroupNorm --> Mish
|
||||
|
||||
Parameters:
|
||||
inp_channels (`int`): Number of input channels.
|
||||
out_channels (`int`): Number of output channels.
|
||||
kernel_size (`int` or `tuple`): Size of the convolving kernel.
|
||||
n_groups (`int`, default `8`): Number of groups to separate the channels into.
|
||||
"""
|
||||
|
||||
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
|
||||
def __init__(
|
||||
self, inp_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]], n_groups: int = 8
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
||||
self.mish = nn.Mish()
|
||||
|
||||
def forward(self, inputs):
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
intermediate_repr = self.conv1d(inputs)
|
||||
intermediate_repr = rearrange_dims(intermediate_repr)
|
||||
intermediate_repr = self.group_norm(intermediate_repr)
|
||||
@@ -701,7 +800,19 @@ class Conv1dBlock(nn.Module):
|
||||
|
||||
# unet_rl.py
|
||||
class ResidualTemporalBlock1D(nn.Module):
|
||||
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
|
||||
"""
|
||||
Residual 1D block with temporal convolutions.
|
||||
|
||||
Parameters:
|
||||
inp_channels (`int`): Number of input channels.
|
||||
out_channels (`int`): Number of output channels.
|
||||
embed_dim (`int`): Embedding dimension.
|
||||
kernel_size (`int` or `tuple`): Size of the convolving kernel.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, inp_channels: int, out_channels: int, embed_dim: int, kernel_size: Union[int, Tuple[int, int]] = 5
|
||||
):
|
||||
super().__init__()
|
||||
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
|
||||
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
|
||||
@@ -713,7 +824,7 @@ class ResidualTemporalBlock1D(nn.Module):
|
||||
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(self, inputs, t):
|
||||
def forward(self, inputs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
inputs : [ batch_size x inp_channels x horizon ]
|
||||
@@ -729,7 +840,9 @@ class ResidualTemporalBlock1D(nn.Module):
|
||||
return out + self.residual_conv(inputs)
|
||||
|
||||
|
||||
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
||||
def upsample_2d(
|
||||
hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
|
||||
) -> torch.Tensor:
|
||||
r"""Upsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
@@ -766,7 +879,9 @@ def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
||||
return output
|
||||
|
||||
|
||||
def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
||||
def downsample_2d(
|
||||
hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
|
||||
) -> torch.Tensor:
|
||||
r"""Downsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
@@ -801,7 +916,9 @@ def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
|
||||
return output
|
||||
|
||||
|
||||
def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
|
||||
def upfirdn2d_native(
|
||||
tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
|
||||
) -> torch.Tensor:
|
||||
up_x = up_y = up
|
||||
down_x = down_y = down
|
||||
pad_x0 = pad_y0 = pad[0]
|
||||
@@ -849,9 +966,14 @@ class TemporalConvLayer(nn.Module):
|
||||
"""
|
||||
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
||||
https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016
|
||||
|
||||
Parameters:
|
||||
in_dim (`int`): Number of input channels.
|
||||
out_dim (`int`): Number of output channels.
|
||||
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim, out_dim=None, dropout=0.0):
|
||||
def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0):
|
||||
super().__init__()
|
||||
out_dim = out_dim or in_dim
|
||||
self.in_dim = in_dim
|
||||
@@ -884,7 +1006,7 @@ class TemporalConvLayer(nn.Module):
|
||||
nn.init.zeros_(self.conv4[-1].weight)
|
||||
nn.init.zeros_(self.conv4[-1].bias)
|
||||
|
||||
def forward(self, hidden_states, num_frames=1):
|
||||
def forward(self, hidden_states: torch.Tensor, num_frames: int = 1) -> torch.Tensor:
|
||||
hidden_states = (
|
||||
hidden_states[None, :].reshape((-1, num_frames) + hidden_states.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
|
||||
from .attention import BasicTransformerBlock
|
||||
from .embeddings import PatchEmbed
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
@@ -100,6 +100,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
||||
@@ -139,9 +142,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
|
||||
self.proj_in = linear_cls(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
@@ -197,9 +200,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
|
||||
self.proj_out = linear_cls(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
@@ -235,6 +238,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
@@ -284,13 +295,21 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = self.proj_in(hidden_states, scale=lora_scale)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states, scale=lora_scale)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
@@ -326,9 +345,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
@@ -128,6 +128,12 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
num_frames (`int`, *optional*, defaults to 1):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
@@ -19,6 +19,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import is_torch_version, logging
|
||||
from ..utils.torch_utils import apply_freeu
|
||||
from .activations import get_activation
|
||||
from .attention import AdaGroupNorm
|
||||
from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
||||
@@ -249,6 +250,7 @@ def get_up_block(
|
||||
add_upsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
resolution_idx=None,
|
||||
transformer_layers_per_block=1,
|
||||
num_attention_heads=None,
|
||||
resnet_groups=None,
|
||||
@@ -281,6 +283,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -295,6 +298,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -314,6 +318,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -337,6 +342,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -362,6 +368,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
@@ -377,6 +384,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -390,6 +398,7 @@ def get_up_block(
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -402,6 +411,7 @@ def get_up_block(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -415,6 +425,7 @@ def get_up_block(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -430,6 +441,7 @@ def get_up_block(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -441,6 +453,7 @@ def get_up_block(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
resolution_idx=resolution_idx,
|
||||
dropout=dropout,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
@@ -453,6 +466,21 @@ def get_up_block(
|
||||
|
||||
|
||||
class AutoencoderTinyBlock(nn.Module):
|
||||
"""
|
||||
Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
|
||||
blocks.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of input channels.
|
||||
out_channels (`int`): The number of output channels.
|
||||
act_fn (`str`):
|
||||
` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
|
||||
`out_channels`.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, out_channels: int, act_fn: str):
|
||||
super().__init__()
|
||||
act_fn = get_activation(act_fn)
|
||||
@@ -1993,6 +2021,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2075,6 +2104,8 @@ class AttnUpBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
@@ -2103,6 +2134,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
@@ -2181,6 +2213,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -2194,11 +2227,30 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -2252,6 +2304,7 @@ class UpBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2292,12 +2345,33 @@ class UpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
@@ -2331,6 +2405,7 @@ class UpDecoderBlock2D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2370,6 +2445,8 @@ class UpDecoderBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, temb=None, scale: float = 1.0):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
||||
@@ -2386,6 +2463,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2449,6 +2527,8 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, temb=None, scale: float = 1.0):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
|
||||
@@ -2469,6 +2549,7 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2553,6 +2634,8 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
self.skip_norm = None
|
||||
self.act = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
@@ -2589,6 +2672,7 @@ class SkipUpBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2651,6 +2735,8 @@ class SkipUpBlock2D(nn.Module):
|
||||
self.skip_norm = None
|
||||
self.act = None
|
||||
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
@@ -2684,6 +2770,7 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2743,6 +2830,7 @@ class ResnetUpsampleBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
for resnet in self.resnets:
|
||||
@@ -2784,6 +2872,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int = None,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
@@ -2873,6 +2962,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -2947,6 +3037,7 @@ class KUpBlock2D(nn.Module):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 5,
|
||||
resnet_eps: float = 1e-5,
|
||||
@@ -2988,6 +3079,7 @@ class KUpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[-1]
|
||||
@@ -3027,6 +3119,7 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
resolution_idx: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 4,
|
||||
resnet_eps: float = 1e-5,
|
||||
@@ -3104,6 +3197,7 @@ class KCrossAttnUpBlock2D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -39,6 +39,9 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
Whether to add downsampling layer before each final output
|
||||
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
||||
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
||||
split_head_dim (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
||||
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
@@ -51,6 +54,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
use_linear_projection: bool = False
|
||||
only_cross_attention: bool = False
|
||||
use_memory_efficient_attention: bool = False
|
||||
split_head_dim: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
transformer_layers_per_block: int = 1
|
||||
|
||||
@@ -77,6 +81,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=self.only_cross_attention,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
attentions.append(attn_block)
|
||||
@@ -179,6 +184,9 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
Whether to add upsampling layer before each final output
|
||||
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
||||
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
||||
split_head_dim (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
||||
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
@@ -192,6 +200,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
use_linear_projection: bool = False
|
||||
only_cross_attention: bool = False
|
||||
use_memory_efficient_attention: bool = False
|
||||
split_head_dim: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
transformer_layers_per_block: int = 1
|
||||
|
||||
@@ -219,6 +228,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=self.only_cross_attention,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
attentions.append(attn_block)
|
||||
@@ -323,6 +333,9 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
Number of attention heads of each spatial transformer block
|
||||
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
||||
enable memory efficient attention https://arxiv.org/abs/2112.05682
|
||||
split_head_dim (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
||||
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
||||
dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
|
||||
Parameters `dtype`
|
||||
"""
|
||||
@@ -332,6 +345,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
num_attention_heads: int = 1
|
||||
use_linear_projection: bool = False
|
||||
use_memory_efficient_attention: bool = False
|
||||
split_head_dim: bool = False
|
||||
dtype: jnp.dtype = jnp.float32
|
||||
transformer_layers_per_block: int = 1
|
||||
|
||||
@@ -356,6 +370,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
depth=self.transformer_layers_per_block,
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
attentions.append(attn_block)
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput, logging
|
||||
from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .activations import get_activation
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -542,6 +542,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resolution_idx=i,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=reversed_cross_attention_dim[i],
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
@@ -613,7 +614,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
return processors
|
||||
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -637,9 +640,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -660,7 +663,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
r"""
|
||||
@@ -731,6 +734,38 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
setattr(upsample_block, "s1", s1)
|
||||
setattr(upsample_block, "s2", s2)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
for k in freeu_keys:
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -755,6 +790,26 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
||||
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
encoder_attention_mask (`torch.Tensor`):
|
||||
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
@@ -940,6 +995,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
|
||||
# 3. down
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
|
||||
@@ -1039,6 +1097,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
|
||||
@@ -92,6 +92,9 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
|
||||
Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682).
|
||||
split_head_dim (`bool`, *optional*, defaults to `False`):
|
||||
Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
|
||||
enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
|
||||
"""
|
||||
|
||||
sample_size: int = 32
|
||||
@@ -116,13 +119,14 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos: bool = True
|
||||
freq_shift: int = 0
|
||||
use_memory_efficient_attention: bool = False
|
||||
split_head_dim: bool = False
|
||||
transformer_layers_per_block: Union[int, Tuple[int]] = 1
|
||||
addition_embed_type: Optional[str] = None
|
||||
addition_time_embed_dim: Optional[int] = None
|
||||
addition_embed_type_num_heads: int = 64
|
||||
projection_class_embeddings_input_dim: Optional[int] = None
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
@@ -134,8 +138,18 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
|
||||
added_cond_kwargs = None
|
||||
if self.addition_embed_type == "text_time":
|
||||
# TODO: how to get this from the config? It's no longer cross_attention_dim
|
||||
text_embeds_dim = 1280
|
||||
# we retrieve the expected `text_embeds_dim` by first checking if the architecture is a refiner
|
||||
# or non-refiner architecture and then by "reverse-computing" from `projection_class_embeddings_input_dim`
|
||||
is_refiner = (
|
||||
5 * self.config.addition_time_embed_dim + self.config.cross_attention_dim
|
||||
== self.config.projection_class_embeddings_input_dim
|
||||
)
|
||||
num_micro_conditions = 5 if is_refiner else 6
|
||||
|
||||
text_embeds_dim = self.config.projection_class_embeddings_input_dim - (
|
||||
num_micro_conditions * self.config.addition_time_embed_dim
|
||||
)
|
||||
|
||||
time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
|
||||
time_ids_dims = time_ids_channels // self.addition_time_embed_dim
|
||||
added_cond_kwargs = {
|
||||
@@ -221,6 +235,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
@@ -244,6 +259,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
@@ -274,6 +290,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
use_linear_projection=self.use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
use_memory_efficient_attention=self.use_memory_efficient_attention,
|
||||
split_head_dim=self.split_head_dim,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
@@ -317,6 +334,13 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
sample (`jnp.ndarray`): (batch, channel, height, width) noisy inputs tensor
|
||||
timestep (`jnp.ndarray` or `float` or `int`): timesteps
|
||||
encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
|
||||
plain tuple.
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..utils.torch_utils import apply_freeu
|
||||
from .resnet import Downsample2D, ResnetBlock2D, TemporalConvLayer, Upsample2D
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
@@ -87,6 +88,7 @@ def get_up_block(
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
num_attention_heads,
|
||||
resolution_idx=None,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
dual_cross_attention=False,
|
||||
@@ -107,6 +109,7 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resolution_idx=resolution_idx,
|
||||
)
|
||||
elif up_block_type == "CrossAttnUpBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
@@ -128,6 +131,7 @@ def get_up_block(
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resolution_idx=resolution_idx,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
@@ -496,6 +500,7 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resolution_idx=None,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -565,6 +570,7 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -577,6 +583,13 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
num_frames=1,
|
||||
cross_attention_kwargs=None,
|
||||
):
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for resnet, temp_conv, attn, temp_attn in zip(
|
||||
self.resnets, self.temp_convs, self.attentions, self.temp_attentions
|
||||
@@ -584,6 +597,19 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -621,6 +647,7 @@ class UpBlock3D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
resolution_idx=None,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
@@ -661,12 +688,32 @@ class UpBlock3D(nn.Module):
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.resolution_idx = resolution_idx
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, num_frames=1):
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
for resnet, temp_conv in zip(self.resnets, self.temp_convs):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
@@ -255,6 +255,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
dual_cross_attention=False,
|
||||
resolution_idx=i,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -366,7 +367,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -390,9 +393,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -454,12 +457,46 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
|
||||
def enable_freeu(self, s1, s2, b1, b2):
|
||||
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
||||
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
setattr(upsample_block, "s1", s1)
|
||||
setattr(upsample_block, "s2", s2)
|
||||
setattr(upsample_block, "b1", b1)
|
||||
setattr(upsample_block, "b2", b2)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism."""
|
||||
freeu_keys = {"s1", "s2", "b1", "b2"}
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
for k in freeu_keys:
|
||||
if hasattr(upsample_block, k) or getattr(upsample_block, k) is not None:
|
||||
setattr(upsample_block, k, None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
@@ -482,6 +519,23 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
||||
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
@@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
|
||||
@@ -25,7 +25,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, logging, replace_example_docstring
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -304,7 +311,10 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -429,6 +439,10 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -540,6 +554,32 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Alt Diffusion v1, v2, and Alt Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -635,6 +675,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
# to deal with lora scaling and other possible forward hooks
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
@@ -656,9 +697,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
)
|
||||
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
@@ -667,7 +707,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
lora_scale=lora_scale,
|
||||
clip_skip=clip_skip,
|
||||
)
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
@@ -729,7 +769,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
@@ -27,7 +27,15 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...models.lora import adjust_lora_scale_text_encoder
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
|
||||
from ...utils import (
|
||||
PIL_INTERPOLATION,
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -302,7 +310,10 @@ class AltDiffusionImg2ImgPipeline(
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
# dynamically adjust the LoRA scale
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
|
||||
if not USE_PEFT_BACKEND:
|
||||
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
|
||||
else:
|
||||
scale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -427,6 +438,10 @@ class AltDiffusionImg2ImgPipeline(
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def run_safety_checker(self, image, device, dtype):
|
||||
@@ -580,6 +595,32 @@ class AltDiffusionImg2ImgPipeline(
|
||||
|
||||
return latents
|
||||
|
||||
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
||||
r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
|
||||
|
||||
The suffixes after the scaling factors represent the stages where they are being applied.
|
||||
|
||||
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
|
||||
that are known to work well for different pipelines such as Alt Diffusion v1, v2, and Alt Diffusion XL.
|
||||
|
||||
Args:
|
||||
s1 (`float`):
|
||||
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
s2 (`float`):
|
||||
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
||||
mitigate "oversmoothing effect" in the enhanced denoising process.
|
||||
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
||||
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
||||
"""
|
||||
if not hasattr(self, "unet"):
|
||||
raise ValueError("The pipeline must have `unet` for using FreeU.")
|
||||
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
|
||||
|
||||
def disable_freeu(self):
|
||||
"""Disables the FreeU mechanism if enabled."""
|
||||
self.unet.disable_freeu()
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
@@ -751,7 +792,8 @@ class AltDiffusionImg2ImgPipeline(
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
@@ -542,7 +542,8 @@ class AudioLDMPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
# 8. Post-processing
|
||||
mel_spectrogram = self.decode_latents(latents)
|
||||
|
||||
@@ -538,7 +538,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@@ -562,9 +564,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
@@ -586,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user