Compare commits

...

44 Commits

Author SHA1 Message Date
sayakpaul 8f5007333b more compile compatibility. 2024-09-09 18:33:12 +05:30
sayakpaul a2c931aaeb merge main 2024-09-09 18:27:16 +05:30
YiYi Xu 8cdcdd9e32 add flux inpaint + img2img + controlnet to auto pipeline (#9367) 2024-09-06 07:14:48 -10:00
Dhruv Nair d269cc8a4e [CI] Quick fix for Cog Video Test (#9373)
update
2024-09-06 15:25:53 +05:30
Aryan 6dfa49963c [core] Freenoise memory improvements (#9262)
* update

* implement prompt interpolation

* make style

* resnet memory optimizations

* more memory optimizations; todo: refactor

* update

* update animatediff controlnet with latest changes

* refactor chunked inference changes

* remove print statements

* update

* chunk -> split

* remove changes from incorrect conflict resolution

* remove changes from incorrect conflict resolution

* add explanation of SplitInferenceModule

* update docs

* Revert "update docs"

This reverts commit c55a50a271.

* update docstring for freenoise split inference

* apply suggestions from review

* add tests

* apply suggestions from review
2024-09-06 12:51:20 +05:30
Haruya Ishikawa 5249a2666e fix one uncaught deprecation warning for accessing vae_latent_channels in VaeImagePreprocessor (#9372)
deprecation warning vae_latent_channels
2024-09-05 07:32:27 -10:00
Linoy Tsaban 55ac421f7b improve README for flux dreambooth lora (#9290)
* improve readme

* improve readme

* improve readme

* improve readme
2024-09-05 17:53:23 +05:30
Dhruv Nair 53051cf282 [CI] Update Single file Nightly Tests (#9357)
* update

* update
2024-09-05 14:33:44 +05:30
Tolga Cangöz 3000551729 Update UNet2DConditionModel's error messages (#9230)
* refactor
2024-09-04 10:49:56 -10:00
Vishnu V Jaddipal 249a9e48e8 Add Flux inpainting and Flux Img2Img (#9135)
---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2024-09-04 10:31:43 -10:00
Fanli Lin 2ee3215949 [tests] make 2 tests device-agnostic (#9347)
* enabel on xpu

* fix style
2024-09-03 16:34:03 -10:00
Eduardo Escobar 8ecf499d8b Enable load_lora_weights for StableDiffusion3InpaintPipeline (#9330)
Enable load_lora_weights for StableDiffusion3InpaintPipeline

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-09-03 15:19:37 -10:00
YiYi Xu dcf320f293 small update on rotary embedding (#9354)
* update

* fix

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-09-03 07:18:33 -10:00
Sayak Paul 8ba90aa706 chore: add a cleaning utility to be useful during training. (#9240) 2024-09-03 15:00:17 +05:30
Aryan 9d49b45b19 [refactor] move positional embeddings to patch embed layer for CogVideoX (#9263)
* remove frame limit in cogvideox

* remove debug prints

* Update src/diffusers/models/transformers/cogvideox_transformer_3d.py

* revert pipeline; remove frame limitation

* revert transformer changes

* address review comments

* add error message

* apply suggestions from review
2024-09-03 14:45:12 +05:30
Dhruv Nair 81da2e1c95 [CI] Add option to dispatch Fast GPU tests on main (#9355)
update
2024-09-03 14:35:13 +05:30
Aryan 24053832b5 [tests] remove/speedup some low signal tests (#9285)
* remove 2 shapes from SDFunctionTesterMixin::test_vae_tiling

* combine freeu enable/disable test to reduce many inference runs

* remove low signal unet test for signature

* remove low signal embeddings test

* remove low signal progress bar test from PipelineTesterMixin

* combine ip-adapter single and multi tests to save many inferences

* fix broken tests

* Update tests/pipelines/test_pipelines_common.py

* Update tests/pipelines/test_pipelines_common.py

* add progress bar tests
2024-09-03 13:59:18 +05:30
Dhruv Nair f6f16a0c11 [CI] More Fast GPU Test Fixes (#9346)
* update

* update

* update

* update
2024-09-03 13:22:38 +05:30
Vishnu V Jaddipal 1c1ccaa03f Xlabs lora fix (#9348)
* Fix ```from_single_file``` for xl_inpaint

* Add basic flux inpaint pipeline

* style, quality, stray print

* Fix stray changes

* Add inpainting model support

* Change lora conversion for xlabs

* Fix stray changes

* Apply suggestions from code review

* style

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-09-03 10:43:43 +05:30
Dhruv Nair 007ad0e2aa [CI] More fixes for Fast GPU Tests on main (#9300)
update
2024-09-02 17:51:48 +05:30
Aryan 0e6a8403f6 [core] Support VideoToVideo with CogVideoX (#9333)
* add vid2vid pipeline for cogvideox

* make fix-copies

* update docs

* fake context parallel cache, vae encode tiling

* add test for cog vid2vid

* use video link from HF docs repo

* add copied from comments; correctly rename test class
2024-09-02 16:54:58 +05:30
Aryan af6c0fb766 [core] CogVideoX memory optimizations in VAE encode (#9340)
fake context parallel cache, vae encode tiling

(cherry picked from commit bf890bca0e)
2024-09-02 15:48:37 +05:30
YiYi Xu d8a16635f4 update runway repo for single_file (#9323)
update to a place holder
2024-08-30 08:51:21 -10:00
Aryan e417d02811 [docs] Add a note on torchao/quanto benchmarks for CogVideoX and memory-efficient inference (#9296)
* add a note on torchao/quanto benchmarks and memory-efficient inference

* apply suggestions from review

* update

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/api/pipelines/cogvideox.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add note on enable sequential cpu offload

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-08-30 13:53:25 +05:30
Dhruv Nair 1d4d71875b [CI] Update Hub Token on nightly tests (#9318)
update
2024-08-30 10:23:50 +05:30
YiYi Xu 61d96c3ae7 refactor rotary embedding 3: so it is not on cpu (#9307)
change get_1d_rotary to accept pos as torch tensors
2024-08-30 01:07:15 +05:30
YiYi Xu 4f495b06dc rotary embedding refactor 2: update comments, fix dtype for use_real=False (#9312)
fix notes and dtype
2024-08-28 23:31:47 -10:00
Anand Kumar 40c13fe5b4 [train_custom_diffusion.py] Fix the LR schedulers when num_train_epochs is passed in a distributed training env (#9308)
* Update train_custom_diffusion.py to fix the LR schedulers for `num_train_epochs`

* Fix saving text embeddings during safe serialization

* Fixed formatting
2024-08-29 14:23:36 +05:30
Sayak Paul 2a3fbc2cc2 [LoRA] support kohya and xlabs loras for flux. (#9295)
* support kohya lora in flux.

* format

* support xlabs

* diffusion_model prefix.

* Apply suggestions from code review

Co-authored-by: apolinário <joaopaulo.passos@gmail.com>

* empty commit.

Co-authored-by: Leommm-byte <leom20031@gmail.com>

---------

Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
Co-authored-by: Leommm-byte <leom20031@gmail.com>
2024-08-29 07:41:46 +05:30
apolinário 089cf798eb Change default for guidance_scalein FLUX (#9305)
To match the original code, 7.0 is too high
2024-08-28 07:39:45 -10:00
Aryan cbc2ec8f44 AnimateDiff prompt travel (#9231)
* update

* implement prompt interpolation

* make style

* resnet memory optimizations

* more memory optimizations; todo: refactor

* update

* update animatediff controlnet with latest changes

* refactor chunked inference changes

* remove print statements

* undo memory optimization changes

* update docstrings

* fix tests

* fix pia tests

* apply suggestions from review

* add tests

* update comment
2024-08-28 14:48:12 +05:30
Frank (Haofan) Wang b5f591fea8 Update __init__.py (#9286) 2024-08-27 07:57:25 -10:00
Dhruv Nair 05b38c3c0d Fix Flux CLIP prompt embeds repeat for num_images_per_prompt > 1 (#9280)
update
2024-08-27 07:41:12 -10:00
Dhruv Nair 8f7fde5701 [CI] Update Release Tests (#9274)
* update

* update
2024-08-27 18:34:00 +05:30
Dhruv Nair a59672655b Fix Freenoise for AnimateDiff V3 checkpoint. (#9288)
update
2024-08-27 18:30:39 +05:30
Marçal Comajoan Cara 9aca79f2b8 Replace transformers.deepspeed with transformers.integrations.deepspeed (#9281)
to avoid "FutureWarning: transformers.deepspeed module is deprecated and will be removed in a future version. Please import deepspeed modules directly from transformers.integrations"

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-08-27 18:08:23 +05:30
sayakpaul 0e7204abcd rejig lora state dict to account for compiled modules. 2024-08-27 17:22:24 +05:30
Steven Liu bbcf2a8589 [docs] Add pipelines to table (#9282)
update pipelines
2024-08-27 12:15:30 +05:30
Álvaro Somoza 4cfb2164fb [IP Adapter] Fix cache_dir and local_files_only for image encoder (#9272)
initial fix
2024-08-26 09:03:08 -10:00
Linoy Tsaban c977966502 [Dreambooth flux] bug fix for dreambooth script (align with dreambooth lora) (#9257)
* fix shape

* fix prompt encoding

* style

* fix device

* add comment
2024-08-26 17:29:58 +05:30
YiYi Xu 1ca0a75567 refactor 3d rope for cogvideox (#9269)
* refactor 3d rope

* repeat -> expand
2024-08-25 11:57:12 -10:00
王奇勋 c1e6a32ae4 [Flux] Support Union ControlNet (#9175)
* refactor
---------

Co-authored-by: haofanwang <haofanwang.ai@gmail.com>
2024-08-25 00:24:21 -10:00
yangpei-comp 77b2162817 Bugfix in pipeline_kandinsky2_2_combined.py: Image type check mismatch (#9256)
Update pipeline_kandinsky2_2_combined.py

Bugfix on image type check mismatch
2024-08-23 08:38:47 -10:00
Dhruv Nair 4e66513a74 [CI] Run Fast + Fast GPU Tests on release branches. (#9255)
* update

* update
2024-08-23 19:34:37 +05:30
106 changed files with 6165 additions and 691 deletions
+5 -5
View File
@@ -79,7 +79,7 @@ jobs:
python utils/print_env.py
- name: Pipeline CUDA Test
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -139,7 +139,7 @@ jobs:
- name: Run nightly PyTorch CUDA tests for non-pipeline modules
if: ${{ matrix.module != 'examples'}}
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -152,7 +152,7 @@ jobs:
- name: Run nightly example tests with Torch
if: ${{ matrix.module == 'examples' }}
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
@@ -209,7 +209,7 @@ jobs:
- name: Run nightly Flax TPU tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
@@ -264,7 +264,7 @@ jobs:
- name: Run Nightly ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
+2 -1
View File
@@ -1,6 +1,7 @@
name: Slow Tests on main
name: Fast GPU Tests on main
on:
workflow_dispatch:
push:
branches:
- main
+389
View File
@@ -0,0 +1,389 @@
# Duplicate workflow to push_tests.yml that is meant to run on release/patch branches as a final check
# Creating a duplicate workflow here is simpler than adding complex path/branch parsing logic to push_tests.yml
# Needs to be updated if push_tests.yml updated
name: (Release) Fast GPU Tests on main
on:
push:
branches:
- "v*.*.*-release"
- "v*.*.*-patch"
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 50000
jobs:
setup_torch_cuda_pipeline_matrix:
name: Setup Torch Pipelines CUDA Slow Tests Matrix
runs-on:
group: aws-general-8-plus
container:
image: diffusers/diffusers-pytorch-cpu
outputs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
- name: Environment
run: |
python utils/print_env.py
- name: Fetch Pipeline Matrix
id: fetch_pipeline_matrix
run: |
matrix=$(python utils/fetch_torch_cuda_pipeline_test_matrix.py)
echo $matrix
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: test-pipelines.json
path: reports
torch_pipelines_cuda_tests:
name: Torch Pipelines CUDA Tests
needs: setup_torch_cuda_pipeline_matrix
strategy:
fail-fast: false
max-parallel: 8
matrix:
module: ${{ fromJson(needs.setup_torch_cuda_pipeline_matrix.outputs.pipeline_test_matrix) }}
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host --gpus 0
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Slow PyTorch CUDA checkpoint tests on Ubuntu
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_pipeline_${{ matrix.module }}_cuda \
tests/pipelines/${{ matrix.module }}
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_pipeline_${{ matrix.module }}_cuda_stats.txt
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
torch_cuda_tests:
name: Torch CUDA Tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host --gpus 0
defaults:
run:
shell: bash
strategy:
fail-fast: false
max-parallel: 2
matrix:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
python -m uv pip install peft@git+https://github.com/huggingface/peft.git
- name: Environment
run: |
python utils/print_env.py
- name: Run PyTorch CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
CUBLAS_WORKSPACE_CONFIG: :16:8
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "not Flax and not Onnx" \
--make-reports=tests_torch_cuda \
tests/${{ matrix.module }}
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_torch_cuda_stats.txt
cat reports/tests_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: torch_cuda_test_reports
path: reports
flax_tpu_tests:
name: Flax TPU Tests
runs-on: docker-tpu
container:
image: diffusers/diffusers-flax-tpu
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run slow Flax TPU tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 0 \
-s -v -k "Flax" \
--make-reports=tests_flax_tpu \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_flax_tpu_stats.txt
cat reports/tests_flax_tpu_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: flax_tpu_test_reports
path: reports
onnx_cuda_tests:
name: ONNX CUDA Tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-onnxruntime-cuda
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
- name: Environment
run: |
python utils/print_env.py
- name: Run slow ONNXRuntime CUDA tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "Onnx" \
--make-reports=tests_onnx_cuda \
tests/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_onnx_cuda_stats.txt
cat reports/tests_onnx_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: onnx_cuda_test_reports
path: reports
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-compile-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
RUN_COMPILE: yes
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_compile_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: torch_compile_test_reports
path: reports
run_xformers_tests:
name: PyTorch xformers CUDA tests
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-xformers-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "xformers" --make-reports=tests_torch_xformers_cuda tests/
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_xformers_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: torch_xformers_test_reports
path: reports
run_examples_tests:
name: Examples PyTorch CUDA tests on Ubuntu
runs-on:
group: aws-g4dn-2xlarge
container:
image: diffusers/diffusers-pytorch-cuda
options: --gpus 0 --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
run: |
nvidia-smi
- name: Install dependencies
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test,training]
- name: Environment
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python utils/print_env.py
- name: Run example tests on GPU
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install timm
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=examples_torch_cuda examples/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/examples_torch_cuda_stats.txt
cat reports/examples_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: examples_test_reports
path: reports
+4
View File
@@ -226,6 +226,8 @@
- sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_flux
title: FluxControlNetModel
- local: api/models/controlnet_hunyuandit
title: HunyuanDiT2DControlNetModel
- local: api/models/controlnet_sd3
@@ -320,6 +322,8 @@
title: Consistency Models
- local: api/pipelines/controlnet
title: ControlNet
- local: api/pipelines/controlnet_flux
title: ControlNet with Flux.1
- local: api/pipelines/controlnet_hunyuandit
title: ControlNet with Hunyuan-DiT
- local: api/pipelines/controlnet_sd3
@@ -0,0 +1,45 @@
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# FluxControlNetModel
FluxControlNetModel is an implementation of ControlNet for Flux.1.
The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
## Loading from the original format
By default the [`FluxControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`].
```py
from diffusers import FluxControlNetPipeline
from diffusers.models import FluxControlNetModel, FluxMultiControlNetModel
controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Canny")
pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet)
controlnet = FluxControlNetModel.from_pretrained("InstantX/FLUX.1-dev-Controlnet-Canny")
controlnet = FluxMultiControlNetModel([controlnet])
pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", controlnet=controlnet)
```
## FluxControlNetModel
[[autodoc]] FluxControlNetModel
## FluxControlNetOutput
[[autodoc]] models.controlnet_flux.FluxControlNetOutput
+18 -1
View File
@@ -77,16 +77,33 @@ CogVideoX-2b requires about 19 GB of GPU memory to decode 49 frames (6 seconds o
- `pipe.enable_model_cpu_offload()`:
- Without enabling cpu offloading, memory usage is `33 GB`
- With enabling cpu offloading, memory usage is `19 GB`
- `pipe.enable_sequential_cpu_offload()`:
- Similar to `enable_model_cpu_offload` but can significantly reduce memory usage at the cost of slow inference
- When enabled, memory usage is under `4 GB`
- `pipe.vae.enable_tiling()`:
- With enabling cpu offloading and tiling, memory usage is `11 GB`
- `pipe.vae.enable_slicing()`
### Quantized inference
[torchao](https://github.com/pytorch/ao) and [optimum-quanto](https://github.com/huggingface/optimum-quanto/) can be used to quantize the text encoder, transformer and VAE modules to lower the memory requirements. This makes it possible to run the model on a free-tier T4 Colab or lower VRAM GPUs!
It is also worth noting that torchao quantization is fully compatible with [torch.compile](/optimization/torch2.0#torchcompile), which allows for much faster inference speed. Additionally, models can be serialized and stored in a quantized datatype to save disk space with torchao. Find examples and benchmarks in the gists below.
- [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897)
- [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa)
## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline
- all
- __call__
## CogVideoXVideoToVideoPipeline
[[autodoc]] CogVideoXVideoToVideoPipeline
- all
- __call__
## CogVideoXPipelineOutput
[[autodoc]] pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipelineOutput
[[autodoc]] pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput
@@ -0,0 +1,48 @@
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ControlNet with Flux.1
FluxControlNetPipeline is an implementation of ControlNet for Flux.1.
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
With a ControlNet model, you can provide an additional control image to condition and control Stable Diffusion generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
The abstract from the paper is:
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
This controlnet code is implemented by [The InstantX Team](https://huggingface.co/InstantX). You can find pre-trained checkpoints for Flux-ControlNet in the table below:
| ControlNet type | Developer | Link |
| -------- | ---------- | ---- |
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny) |
| Depth | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Depth) |
| Union | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union) |
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## FluxControlNetPipeline
[[autodoc]] FluxControlNetPipeline
- all
- __call__
## FluxPipelineOutput
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
+12
View File
@@ -163,3 +163,15 @@ image.save("flux-fp8-dev.png")
[[autodoc]] FluxPipeline
- all
- __call__
## FluxImg2ImgPipeline
[[autodoc]] FluxImg2ImgPipeline
- all
- __call__
## FluxInpaintPipeline
[[autodoc]] FluxInpaintPipeline
- all
- __call__
+17 -16
View File
@@ -30,63 +30,64 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| Pipeline | Tasks |
|---|---|
| [AltDiffusion](alt_diffusion) | image2image |
| [aMUSEd](amused) | text2image |
| [AnimateDiff](animatediff) | text2video |
| [Attend-and-Excite](attend_and_excite) | text2image |
| [Audio Diffusion](audio_diffusion) | image2audio |
| [AudioLDM](audioldm) | text2audio |
| [AudioLDM2](audioldm2) | text2audio |
| [AuraFlow](auraflow) | text2image |
| [BLIP Diffusion](blip_diffusion) | text2image |
| [CogVideoX](cogvideox) | text2video |
| [Consistency Models](consistency_models) | unconditional image generation |
| [ControlNet](controlnet) | text2image, image2image, inpainting |
| [ControlNet with Flux.1](controlnet_flux) | text2image |
| [ControlNet with Hunyuan-DiT](controlnet_hunyuandit) | text2image |
| [ControlNet with Stable Diffusion 3](controlnet_sd3) | text2image |
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
| [ControlNet-XS](controlnetxs) | text2image |
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
| [Cycle Diffusion](cycle_diffusion) | image2image |
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
| [DDIM](ddim) | unconditional image generation |
| [DDPM](ddpm) | unconditional image generation |
| [DeepFloyd IF](deepfloyd_if) | text2image, image2image, inpainting, super-resolution |
| [DiffEdit](diffedit) | inpainting |
| [DiT](dit) | text2image |
| [GLIGEN](stable_diffusion/gligen) | text2image |
| [Flux](flux) | text2image |
| [Hunyuan-DiT](hunyuandit) | text2image |
| [I2VGen-XL](i2vgenxl) | text2video |
| [InstructPix2Pix](pix2pix) | image editing |
| [Kandinsky 2.1](kandinsky) | text2image, image2image, inpainting, interpolation |
| [Kandinsky 2.2](kandinsky_v22) | text2image, image2image, inpainting |
| [Kandinsky 3](kandinsky3) | text2image, image2image |
| [Kolors](kolors) | text2image |
| [Latent Consistency Models](latent_consistency_models) | text2image |
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [LDM3D](stable_diffusion/ldm3d_diffusion) | text2image, text-to-3D, text-to-pano, upscaling |
| [Latte](latte) | text2image |
| [LEDITS++](ledits_pp) | image editing |
| [Lumina-T2X](lumina) | text2image |
| [Marigold](marigold) | depth |
| [MultiDiffusion](panorama) | text2image |
| [MusicLDM](musicldm) | text2audio |
| [PAG](pag) | text2image |
| [Paint by Example](paint_by_example) | inpainting |
| [ParaDiGMS](paradigms) | text2image |
| [Pix2Pix Zero](pix2pix_zero) | image editing |
| [PIA](pia) | image2video |
| [PixArt-α](pixart) | text2image |
| [PNDM](pndm) | unconditional image generation |
| [RePaint](repaint) | inpainting |
| [Score SDE VE](score_sde_ve) | unconditional image generation |
| [PixArt-Σ](pixart_sigma) | text2image |
| [Self-Attention Guidance](self_attention_guidance) | text2image |
| [Semantic Guidance](semantic_stable_diffusion) | text2image |
| [Shap-E](shap_e) | text-to-3D, image-to-3D |
| [Spectrogram Diffusion](spectrogram_diffusion) | |
| [Stable Audio](stable_audio) | text2audio |
| [Stable Cascade](stable_cascade) | text2image |
| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |
| [Stable Diffusion Model Editing](model_editing) | model editing |
| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |
| [Stable Diffusion XL Turbo](stable_diffusion/sdxl_turbo) | text2image, image2image, inpainting |
| [Stable unCLIP](stable_unclip) | text2image, image variation |
| [Stochastic Karras VE](stochastic_karras_ve) | unconditional image generation |
| [T2I-Adapter](stable_diffusion/adapter) | text2image |
| [Text2Video](text_to_video) | text2video, video2video |
| [Text2Video-Zero](text_to_video_zero) | text2video |
| [unCLIP](unclip) | text2image, image variation |
| [Unconditional Latent Diffusion](latent_diffusion_uncond) | unconditional image generation |
| [UniDiffuser](unidiffuser) | text2image, image2text, image variation, text variation, unconditional image generation, unconditional audio generation |
| [Value-guided planning](value_guided_sampling) | value guided sampling |
| [Versatile Diffusion](versatile_diffusion) | text2image, image variation |
| [VQ Diffusion](vq_diffusion) | text2image |
| [Wuerstchen](wuerstchen) | text2image |
## DiffusionPipeline
@@ -314,11 +314,12 @@ def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_di
for x, y in zip(modifier_token_id, args.modifier_token):
learned_embeds_dict = {}
learned_embeds_dict[y] = learned_embeds[x]
filename = f"{output_dir}/{y}.bin"
if safe_serialization:
filename = f"{output_dir}/{y}.safetensors"
safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
else:
filename = f"{output_dir}/{y}.bin"
torch.save(learned_embeds_dict, filename)
@@ -1040,17 +1041,22 @@ def main(args):
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
)
# Prepare everything with our `accelerator`.
@@ -1065,8 +1071,14 @@ def main(args):
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+41 -8
View File
@@ -8,8 +8,10 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced
>
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 (w/ all components trained) can exceed 40GB of VRAM for training.
> For more tips & guidance on training on a resource-constrained device please visit [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX.md)
> 2) [`ostris`'s guide](https://github.com/ostris/ai-toolkit?tab=readme-ov-file#flux1-training)
> [!NOTE]
> **Gated model**
@@ -100,8 +102,10 @@ accelerate launch train_dreambooth_flux.py \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
@@ -120,15 +124,23 @@ To better track our training experiments, we're using the following flags in the
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
By using prodigy we can "eliminate" the need for manual learning rate tuning. read more [here](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers).
to use prodigy, specify
```bash
--optimizer="prodigy"
```
> [!TIP]
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
To perform DreamBooth with LoRA, run:
```bash
@@ -144,8 +156,10 @@ accelerate launch train_dreambooth_lora_flux.py \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
@@ -162,6 +176,7 @@ Alongside the transformer, fine-tuning of the CLIP text encoder is also supporte
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
> [!NOTE]
> This is still an experimental feature.
> FLUX.1 has 2 text encoders (CLIP L/14 and T5-v1.1-XXL).
By enabling `--train_text_encoder`, fine-tuning of the **CLIP encoder** is performed.
> At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
@@ -180,8 +195,10 @@ accelerate launch train_dreambooth_lora_flux.py \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-5 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
@@ -191,5 +208,21 @@ accelerate launch train_dreambooth_lora_flux.py \
--push_to_hub
```
## Memory Optimizations
As mentioned, Flux Dreambooth LoRA training is very memory intensive Here are some options (some still experimental) for a more memory efficient training.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
### Gradient Checkpointing and Accumulation
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
Make sure to install `bitsandbytes` if you want to do so.
### latent caching
When training w/o validation runs, we can pre-encode the training images with the vae, and then delete it to free up some memory.
to enable `latent_caching`, first, use the version in [this PR](https://github.com/huggingface/diffusers/blob/1b195933d04e4c8281a2634128c0d2d380893f73/examples/dreambooth/train_dreambooth_lora_flux.py), and then pass `--cache_latents`
## Other notes
Thanks to `bghira` for their help with reviewing & insight sharing ♥️
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
+72 -58
View File
@@ -842,7 +842,7 @@ class PromptDataset(Dataset):
return example
def tokenize_prompt(tokenizer, prompt, max_sequence_length=512):
def tokenize_prompt(tokenizer, prompt, max_sequence_length):
text_inputs = tokenizer(
prompt,
padding="max_length",
@@ -863,20 +863,26 @@ def _encode_prompt_with_t5(
prompt=None,
num_images_per_prompt=1,
device=None,
text_input_ids=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
@@ -896,22 +902,28 @@ def _encode_prompt_with_clip(
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
@@ -932,17 +944,19 @@ def encode_prompt(
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
dtype = text_encoders[0].dtype
device = device if device is not None else text_encoders[1].device
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
tokenizer=tokenizers[0],
prompt=prompt,
device=device if device is not None else text_encoders[0].device,
device=device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
)
prompt_embeds = _encode_prompt_with_t5(
@@ -951,7 +965,8 @@ def encode_prompt(
max_sequence_length=max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[1].device,
device=device,
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
)
text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
@@ -1499,7 +1514,25 @@ def main(args):
)
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
tokens_two = tokenize_prompt(tokenizer_two, prompts, max_sequence_length=512)
tokens_two = tokenize_prompt(
tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
)
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two],
max_sequence_length=args.max_sequence_length,
prompt=prompts,
)
else:
if args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two],
max_sequence_length=args.max_sequence_length,
prompt=args.instance_prompt,
)
# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1553,41 +1586,22 @@ def main(args):
guidance = None
# Predict the noise residual
if not args.train_text_encoder:
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
else:
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
prompt=None,
text_input_ids_list=[tokens_one, tokens_two],
)
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
# upscaling height & width as discussed in https://github.com/huggingface/diffusers/pull/9257#discussion_r1731108042
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=int(model_input.shape[2]),
width=int(model_input.shape[3]),
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
vae_scale_factor=vae_scale_factor,
)
@@ -15,7 +15,6 @@
import argparse
import copy
import gc
import itertools
import logging
import math
@@ -56,6 +55,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
@@ -210,9 +210,7 @@ def log_validation(
}
)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])
return images
@@ -1107,9 +1105,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])
# Handle the repository creation
if accelerator.is_main_process:
@@ -1455,12 +1451,10 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1795,11 +1789,11 @@ def main(args):
pipeline_args=pipeline_args,
epoch=epoch,
)
objs = []
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])
torch.cuda.empty_cache()
gc.collect()
clear_objs_and_retain_memory(objs=objs)
# Save the lora layers
accelerator.wait_for_everyone()
+8
View File
@@ -89,6 +89,7 @@ else:
"ControlNetXSAdapter",
"DiTTransformer2DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -254,8 +255,11 @@ else:
"BlipDiffusionPipeline",
"CLIPImageProjection",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CycleDiffusionPipeline",
"FluxControlNetPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -554,6 +558,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ControlNetXSAdapter,
DiTTransformer2DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -697,8 +702,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AuraFlowPipeline,
CLIPImageProjection,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CycleDiffusionPipeline,
FluxControlNetPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
FluxPipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
+1 -1
View File
@@ -569,7 +569,7 @@ class VaeImageProcessor(ConfigMixin):
channel = image.shape[1]
# don't need any preprocess if the image is latents
if channel == self.vae_latent_channels:
if channel == self.config.vae_latent_channels:
return image
height, width = self.get_default_height_width(image, height, width)
+2
View File
@@ -208,6 +208,8 @@ class IPAdapterMixin:
pretrained_model_name_or_path_or_dict,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
).to(self.device, dtype=self.dtype)
self.register_modules(image_encoder=image_encoder)
else:
+13 -6
View File
@@ -38,6 +38,7 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.torch_utils import is_compiled_module
if is_transformers_available():
@@ -371,6 +372,7 @@ class LoraBaseMixin:
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
model = model._orig_mod if is_compiled_module(model) else model
if issubclass(model.__class__, ModelMixin):
model.unload_lora()
elif issubclass(model.__class__, PreTrainedModel):
@@ -446,6 +448,7 @@ class LoraBaseMixin:
model = getattr(self, fuse_component, None)
if model is not None:
model = model._orig_mod if is_compiled_module(model) else model
# check if diffusers model
if issubclass(model.__class__, ModelMixin):
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
@@ -506,6 +509,7 @@ class LoraBaseMixin:
model = getattr(self, fuse_component, None)
if model is not None:
model = model._orig_mod if is_compiled_module(model) else model
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
@@ -569,6 +573,7 @@ class LoraBaseMixin:
_component_adapter_weights.setdefault(component, [])
_component_adapter_weights[component].append(component_adapter_weights)
model = model._orig_mod if is_compiled_module(model) else model
if issubclass(model.__class__, ModelMixin):
model.set_adapters(adapter_names, _component_adapter_weights[component])
elif issubclass(model.__class__, PreTrainedModel):
@@ -581,6 +586,7 @@ class LoraBaseMixin:
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
model = model._orig_mod if is_compiled_module(model) else model
if issubclass(model.__class__, ModelMixin):
model.disable_lora()
elif issubclass(model.__class__, PreTrainedModel):
@@ -593,6 +599,7 @@ class LoraBaseMixin:
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
model = model._orig_mod if is_compiled_module(model) else model
if issubclass(model.__class__, ModelMixin):
model.enable_lora()
elif issubclass(model.__class__, PreTrainedModel):
@@ -614,6 +621,7 @@ class LoraBaseMixin:
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if model is not None:
model = model._orig_mod if is_compiled_module(model) else model
if issubclass(model.__class__, ModelMixin):
model.delete_adapters(adapter_names)
elif issubclass(model.__class__, PreTrainedModel):
@@ -645,6 +653,7 @@ class LoraBaseMixin:
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
model = model._orig_mod if is_compiled_module(model) else model
if model is not None and issubclass(model.__class__, ModelMixin):
for module in model.modules():
if isinstance(module, BaseTunerLayer):
@@ -666,12 +675,10 @@ class LoraBaseMixin:
for component in self._lora_loadable_modules:
model = getattr(self, component, None)
if (
model is not None
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
and hasattr(model, "peft_config")
):
set_adapters[component] = list(model.peft_config.keys())
if model is not None:
model = model._orig_mod if is_compiled_module(model) else model
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)) and hasattr(model, "peft_config"):
set_adapters[component] = list(model.peft_config.keys())
return set_adapters
@@ -14,6 +14,8 @@
import re
import torch
from ..utils import is_peft_version, logging
@@ -326,3 +328,296 @@ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
prefix = "text_encoder_2."
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
return {new_name: alpha}
# The utilities under `_convert_kohya_flux_lora_to_diffusers()`
# are taken from https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
# All credits go to `kohya-ss`.
def _convert_kohya_flux_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
# scale weight by alpha and dim
rank = down_weight.shape[0]
alpha = sds_sd.pop(sds_key + ".alpha").item() # alpha is scalar
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
# calculate scale_down and scale_up to keep the same value. if scale is 4, scale_down is 2 and scale_up is 2
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
# scale weight by alpha and dim
alpha = sds_sd.pop(sds_key + ".alpha")
scale = alpha / sd_lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# check upweight is sparse or not
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")
# make ai-toolkit weight
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
else:
# down_weight is chunked to each split
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
# up_weight is sparse: only non-zero values are copied to each split
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
def _convert_sd_scripts_to_ai_toolkit(sds_sd):
ait_sd = {}
for i in range(19):
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out.0",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_0",
f"transformer.transformer_blocks.{i}.ff.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_2",
f"transformer.transformer_blocks.{i}.ff.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mod_lin",
f"transformer.transformer_blocks.{i}.norm1.linear",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_0",
f"transformer.transformer_blocks.{i}.ff_context.net.0.proj",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_2",
f"transformer.transformer_blocks.{i}.ff_context.net.2",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mod_lin",
f"transformer.transformer_blocks.{i}.norm1_context.linear",
)
for i in range(38):
_convert_to_ai_toolkit_cat(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
[
f"transformer.single_transformer_blocks.{i}.attn.to_q",
f"transformer.single_transformer_blocks.{i}.attn.to_k",
f"transformer.single_transformer_blocks.{i}.attn.to_v",
f"transformer.single_transformer_blocks.{i}.proj_mlp",
],
dims=[3072, 3072, 3072, 12288],
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_linear2",
f"transformer.single_transformer_blocks.{i}.proj_out",
)
_convert_to_ai_toolkit(
sds_sd,
ait_sd,
f"lora_unet_single_blocks_{i}_modulation_lin",
f"transformer.single_transformer_blocks.{i}.norm.linear",
)
if len(sds_sd) > 0:
logger.warning(f"Unsuppored keys for ai-toolkit: {sds_sd.keys()}")
return ait_sd
return _convert_sd_scripts_to_ai_toolkit(state_dict)
# Adapted from https://gist.github.com/Leommm-byte/6b331a1e9bd53271210b26543a7065d6
# Some utilities were reused from
# https://github.com/kohya-ss/sd-scripts/blob/a61cf73a5cb5209c3f4d1a3688dd276a4dfd1ecb/networks/convert_flux_lora.py
def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
new_state_dict = {}
orig_keys = list(old_state_dict.keys())
def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
down_weight = sds_sd.pop(sds_key)
up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
# calculate dims if not provided
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# make ai-toolkit weight
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
# down_weight is copied to each split
ait_sd.update({k: down_weight for k in ait_down_keys})
# up_weight is split to each split
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
for old_key in orig_keys:
# Handle double_blocks
if old_key.startswith(("diffusion_model.double_blocks", "double_blocks")):
block_num = re.search(r"double_blocks\.(\d+)", old_key).group(1)
new_key = f"transformer.transformer_blocks.{block_num}"
if "processor.proj_lora1" in old_key:
new_key += ".attn.to_out.0"
elif "processor.proj_lora2" in old_key:
new_key += ".attn.to_add_out"
# Handle text latents.
elif "processor.qkv_lora2" in old_key and "up" not in old_key:
handle_qkv(
old_state_dict,
new_state_dict,
old_key,
[
f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
],
)
# continue
# Handle image latents.
elif "processor.qkv_lora1" in old_key and "up" not in old_key:
handle_qkv(
old_state_dict,
new_state_dict,
old_key,
[
f"transformer.transformer_blocks.{block_num}.attn.to_q",
f"transformer.transformer_blocks.{block_num}.attn.to_k",
f"transformer.transformer_blocks.{block_num}.attn.to_v",
],
)
# continue
if "down" in old_key:
new_key += ".lora_A.weight"
elif "up" in old_key:
new_key += ".lora_B.weight"
# Handle single_blocks
elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"):
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
new_key = f"transformer.single_transformer_blocks.{block_num}"
if "proj_lora1" in old_key or "proj_lora2" in old_key:
new_key += ".proj_out"
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
new_key += ".norm.linear"
if "down" in old_key:
new_key += ".lora_A.weight"
elif "up" in old_key:
new_key += ".lora_B.weight"
else:
# Handle other potential key patterns here
new_key = old_key
# Since we already handle qkv above.
if "qkv" not in old_key:
new_state_dict[new_key] = old_state_dict.pop(old_key)
if len(old_state_dict) > 0:
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
return new_state_dict
+23 -1
View File
@@ -30,8 +30,14 @@ from ..utils import (
logging,
scale_lora_layers,
)
from ..utils.torch_utils import is_compiled_module
from .lora_base import LoraBaseMixin
from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
from .lora_conversion_utils import (
_convert_kohya_flux_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_xlabs_flux_lora_to_diffusers,
_maybe_map_sgm_blocks_to_diffusers,
)
if is_transformers_available():
@@ -1583,6 +1589,20 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
allow_pickle=allow_pickle,
)
# TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
return (state_dict, None) if return_alphas else state_dict
is_xlabs = any("processor" in k for k in state_dict)
if is_xlabs:
state_dict = _convert_xlabs_flux_lora_to_diffusers(state_dict)
# xlabs doesn't use `alpha`.
return (state_dict, None) if return_alphas else state_dict
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
@@ -1733,6 +1753,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
if is_compiled_module(transformer):
state_dict = {"_orig_mod." + k: v for k, v in state_dict.items()}
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
if incompatible_keys is not None:
+2 -2
View File
@@ -91,11 +91,11 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"xl_inpaint": {"pretrained_model_name_or_path": "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"},
"playground-v2-5": {"pretrained_model_name_or_path": "playgroundai/playground-v2.5-1024px-aesthetic"},
"upscale": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-x4-upscaler"},
"inpainting": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-inpainting"},
"inpainting": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8-inpainting"},
"inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
"controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
"v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
"v1": {"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5"},
"v1": {"pretrained_model_name_or_path": "Lykon/dreamshaper-8"},
"stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
"stable_cascade_stage_b_lite": {
"pretrained_model_name_or_path": "stabilityai/stable-cascade",
+2 -2
View File
@@ -35,7 +35,7 @@ if is_torch_available():
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
_import_structure["controlnet"] = ["ControlNetModel"]
_import_structure["controlnet_flux"] = ["FluxControlNetModel"]
_import_structure["controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
_import_structure["controlnet_sparsectrl"] = ["SparseControlNetModel"]
@@ -88,7 +88,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
)
from .controlnet import ControlNetModel
from .controlnet_flux import FluxControlNetModel
from .controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from .controlnet_sparsectrl import SparseControlNetModel
+41 -6
View File
@@ -972,15 +972,32 @@ class FreeNoiseTransformerBlock(nn.Module):
return frame_indices
def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
if weighting_scheme == "pyramid":
if weighting_scheme == "flat":
weights = [1.0] * num_frames
elif weighting_scheme == "pyramid":
if num_frames % 2 == 0:
# num_frames = 4 => [1, 2, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
mid = num_frames // 2
weights = list(range(1, mid + 1))
weights = weights + weights[::-1]
else:
# num_frames = 5 => [1, 2, 3, 2, 1]
weights = list(range(1, num_frames // 2 + 1))
weights = weights + [num_frames // 2 + 1] + weights[::-1]
mid = (num_frames + 1) // 2
weights = list(range(1, mid))
weights = weights + [mid] + weights[::-1]
elif weighting_scheme == "delayed_reverse_sawtooth":
if num_frames % 2 == 0:
# num_frames = 4 => [0.01, 2, 2, 1]
mid = num_frames // 2
weights = [0.01] * (mid - 1) + [mid]
weights = weights + list(range(mid, 0, -1))
else:
# num_frames = 5 => [0.01, 0.01, 3, 2, 1]
mid = (num_frames + 1) // 2
weights = [0.01] * mid
weights = weights + list(range(mid, 0, -1))
else:
raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
@@ -1087,8 +1104,26 @@ class FreeNoiseTransformerBlock(nn.Module):
accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
num_times_accumulated[:, frame_start:frame_end] += weights
hidden_states = torch.where(
num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
# TODO(aryan): Maybe this could be done in a better way.
#
# Previously, this was:
# hidden_states = torch.where(
# num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
# )
#
# The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
# spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
# from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
# looked into this deeply because other memory optimizations led to more pronounced reductions.
hidden_states = torch.cat(
[
torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
for accumulated_split, num_times_split in zip(
accumulated_values.split(self.context_length, dim=1),
num_times_accumulated.split(self.context_length, dim=1),
)
],
dim=1,
).to(dtype)
# 3. Feed-forward
@@ -999,6 +999,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
# number of temporal frames.
self.num_latent_frames_batch_size = 2
self.num_sample_frames_batch_size = 8
# We make the minimum height and width of sample for tiling half that of the generally supported
self.tile_sample_min_height = sample_height // 2
@@ -1081,6 +1082,29 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
return self.tiled_encode(x)
frame_batch_size = self.num_sample_frames_batch_size
enc = []
for i in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
end_frame = frame_batch_size * (i + 1) + remaining_frames
x_intermediate = x[:, :, start_frame:end_frame]
x_intermediate = self.encoder(x_intermediate)
if self.quant_conv is not None:
x_intermediate = self.quant_conv(x_intermediate)
enc.append(x_intermediate)
self._clear_fake_context_parallel_cache()
enc = torch.cat(enc, dim=2)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -1094,13 +1118,17 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
h = self.encoder(x)
if self.quant_conv is not None:
h = self.quant_conv(h)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
@@ -1172,6 +1200,75 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
)
return b
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
# For a rough memory estimate, take a look at the `tiled_decode` method.
batch_size, num_channels, num_frames, height, width = x.shape
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_latent_min_height - blend_extent_height
row_limit_width = self.tile_latent_min_width - blend_extent_width
frame_batch_size = self.num_sample_frames_batch_size
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
time = []
for k in range(num_frames // frame_batch_size):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = x[
:,
:,
start_frame:end_frame,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
if self.quant_conv is not None:
tile = self.quant_conv(tile)
time.append(tile)
self._clear_fake_context_parallel_cache()
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
+134 -3
View File
@@ -54,6 +54,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
):
super().__init__()
self.out_channels = in_channels
@@ -101,6 +102,10 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.union = num_mode is not None
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.gradient_checkpointing = False
@@ -173,8 +178,8 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
def from_transformer(
cls,
transformer,
num_layers=4,
num_single_layers=10,
num_layers: int = 4,
num_single_layers: int = 10,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
load_weights_from_transformer=True,
@@ -205,6 +210,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor = None,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
@@ -221,6 +227,12 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_mode (`torch.Tensor`):
The mode tensor of shape `(batch_size, 1)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
@@ -272,6 +284,15 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# union mode emb
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
@@ -367,7 +388,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
#
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
@@ -384,3 +404,114 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
)
class FluxMultiControlNetModel(ModelMixin):
r"""
`FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
compatible with `FluxControlNetModel`.
Args:
controlnets (`List[FluxControlNetModel]`):
Provides additional conditioning to the unet during the denoising process. You must set multiple
`FluxControlNetModel` as a list.
"""
def __init__(self, controlnets):
super().__init__()
self.nets = nn.ModuleList(controlnets)
def forward(
self,
hidden_states: torch.FloatTensor,
controlnet_cond: List[torch.tensor],
controlnet_mode: List[torch.tensor],
conditioning_scale: List[float],
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[FluxControlNetOutput, Tuple]:
# ControlNet-Union with multiple conditions
# only load one ControlNet for saving memories
if len(self.nets) == 1 and self.nets[0].union:
controlnet = self.nets[0]
for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
# Regular Multi-ControlNets
# load all ControlNets into memories
else:
for i, (image, mode, scale, controlnet) in enumerate(
zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
):
block_samples, single_block_samples = controlnet(
hidden_states=hidden_states,
controlnet_cond=image,
controlnet_mode=mode[:, None],
conditioning_scale=scale,
timestep=timestep,
guidance=guidance,
pooled_projections=pooled_projections,
encoder_hidden_states=encoder_hidden_states,
txt_ids=txt_ids,
img_ids=img_ids,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
# merge samples
if i == 0:
control_block_samples = block_samples
control_single_block_samples = single_block_samples
else:
control_block_samples = [
control_block_sample + block_sample
for control_block_sample, block_sample in zip(control_block_samples, block_samples)
]
control_single_block_samples = [
control_single_block_sample + block_sample
for control_single_block_sample, block_sample in zip(
control_single_block_samples, single_block_samples
)
]
return control_block_samples, control_single_block_samples
@@ -691,7 +691,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
emb = self.time_embedding(t_emb, timestep_cond)
emb = emb.repeat_interleave(sample_num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(sample_num_frames, dim=0)
# 2. pre-process
batch_size, channels, num_frames, height, width = sample.shape
+109 -57
View File
@@ -342,15 +342,58 @@ class CogVideoXPatchEmbed(nn.Module):
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_positional_embeddings: bool = True,
) -> None:
super().__init__()
self.patch_size = patch_size
self.embed_dim = embed_dim
self.sample_height = sample_height
self.sample_width = sample_width
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.use_positional_embeddings = use_positional_embeddings
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
if use_positional_embeddings:
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.embed_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
joint_pos_embedding = torch.zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
return joint_pos_embedding
def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Args:
@@ -371,6 +414,21 @@ class CogVideoXPatchEmbed(nn.Module):
embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
if self.use_positional_embeddings:
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
if (
self.sample_height != height
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
else:
pos_embedding = self.pos_embedding
embeds = embeds + pos_embedding
return embeds
@@ -391,15 +449,16 @@ def get_3d_rotary_pos_embed(
The size of the temporal dimension.
theta (`float`):
Scaling factor for frequency computation.
use_real (`bool`):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns:
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
"""
if use_real is not True:
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
start, stop = crops_coords
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
# Compute dimensions for each axis
@@ -408,54 +467,37 @@ def get_3d_rotary_pos_embed(
dim_w = embed_dim // 8 * 3
# Temporal frequencies
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
grid_t = torch.from_numpy(grid_t).float()
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
# Spatial frequencies for height and width
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
grid_h = torch.from_numpy(grid_h).float()
grid_w = torch.from_numpy(grid_w).float()
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
# Broadcast and concatenate tensors along specified dimension
def broadcast(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = {len(t.shape) for t in tensors}
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*(list(t.shape) for t in tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
), "invalid dimensions for broadcastable concatenation"
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
return torch.cat(tensors, dim=dim)
# BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
freqs_t = freqs_t[:, None, None, :].expand(
-1, grid_size_h, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_w, dim_t
freqs_h = freqs_h[None, :, None, :].expand(
temporal_size, -1, grid_size_w, -1
) # temporal_size, grid_size_h, grid_size_2, dim_h
freqs_w = freqs_w[None, None, :, :].expand(
temporal_size, grid_size_h, -1, -1
) # temporal_size, grid_size_h, grid_size_2, dim_w
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
freqs = torch.cat(
[freqs_t, freqs_h, freqs_w], dim=-1
) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w)
freqs = freqs.view(
temporal_size * grid_size_h * grid_size_w, -1
) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w)
return freqs
t, h, w, d = freqs.shape
freqs = freqs.view(t * h * w, d)
# Generate sine and cosine components
sin = freqs.sin()
cos = freqs.cos()
if use_real:
return cos, sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
cos = combine_time_height_width(t_cos, h_cos, w_cos)
sin = combine_time_height_width(t_sin, h_sin, w_sin)
return cos, sin
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
@@ -530,7 +572,7 @@ def get_1d_rotary_pos_embed(
linear_factor=1.0,
ntk_factor=1.0,
repeat_interleave_real=True,
freqs_dtype=torch.float32, # torch.float32 (hunyuan, stable audio), torch.float64 (flux)
freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux)
):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
@@ -561,21 +603,30 @@ def get_1d_rotary_pos_embed(
assert dim % 2 == 0
if isinstance(pos, int):
pos = np.arange(pos)
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # type: ignore # [S]
theta = theta * ntk_factor
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
freqs = (
1.0
/ (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
/ linear_factor
) # [D/2]
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
if use_real and repeat_interleave_real:
# flux, hunyuan-dit, cogvideox
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
elif use_real:
# stable audio
freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D]
freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).float() # complex64 # [S, D/2]
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
@@ -606,11 +657,11 @@ def apply_rotary_emb(
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
# Use for example in Lumina
# Used for flux, cogvideox, hunyuan-dit
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
elif use_real_unbind_dim == -2:
# Use for example in Stable Audio
# Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
@@ -620,6 +671,7 @@ def apply_rotary_emb(
return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
@@ -638,7 +690,7 @@ class FluxPosEmbed(nn.Module):
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.squeeze().float().cpu().numpy()
pos = ids.squeeze().float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
@@ -23,7 +23,7 @@ from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -239,33 +239,29 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
text_embed_dim=text_embed_dim,
bias=True,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
temporal_compression_ratio=temporal_compression_ratio,
max_text_seq_length=max_text_seq_length,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
# 2. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
# 3. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
@@ -284,7 +280,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
# 4. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
@@ -422,20 +418,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
# 3. Position embedding
text_seq_length = encoder_hidden_states.shape[1]
if not self.config.use_rotary_positional_embeddings:
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# 4. Transformer blocks
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
@@ -471,11 +460,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm_final(hidden_states)
hidden_states = hidden_states[:, text_seq_length:]
# 5. Final block
# 4. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 6. Unpatchify
# 5. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
@@ -463,7 +463,6 @@ class UNet2DConditionModel(
dropout=dropout,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_num_groups is not None:
@@ -599,7 +598,7 @@ class UNet2DConditionModel(
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
)
else:
self.encoder_hid_proj = None
@@ -679,7 +678,9 @@ class UNet2DConditionModel(
# Kandinsky 2.2 ControlNet
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
raise ValueError(
f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'."
)
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
if attention_type in ["gated", "gated-text-image"]:
@@ -990,7 +991,7 @@ class UNet2DConditionModel(
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
# Kandinsky 2.2 - style
# Kandinsky 2.2 ControlNet - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
@@ -1009,7 +1010,7 @@ class UNet2DConditionModel(
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
@@ -1018,14 +1019,14 @@ class UNet2DConditionModel(
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
@@ -1140,7 +1141,6 @@ class UNet2DConditionModel(
# 1. time
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
if class_emb is not None:
+41 -63
View File
@@ -116,7 +116,7 @@ class AnimateDiffTransformer3D(nn.Module):
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformers blocks
@@ -187,12 +187,12 @@ class AnimateDiffTransformer3D(nn.Module):
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(input=hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
@@ -200,7 +200,7 @@ class AnimateDiffTransformer3D(nn.Module):
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(input=hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
@@ -344,7 +344,7 @@ class DownBlockMotion(nn.Module):
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
@@ -352,7 +352,7 @@ class DownBlockMotion(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)
output_states = output_states + (hidden_states,)
@@ -531,25 +531,18 @@ class CrossAttnDownBlockMotion(nn.Module):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
@@ -563,7 +556,7 @@ class CrossAttnDownBlockMotion(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states=hidden_states)
output_states = output_states + (hidden_states,)
@@ -757,25 +750,18 @@ class CrossAttnUpBlockMotion(nn.Module):
temb,
**ckpt_kwargs,
)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
@@ -783,7 +769,7 @@ class CrossAttnUpBlockMotion(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
return hidden_states
@@ -929,13 +915,13 @@ class UpBlockMotion(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
hidden_states = motion_module(hidden_states, num_frames=num_frames)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)
return hidden_states
@@ -1080,10 +1066,19 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
if cross_attention_kwargs.get("scale", None) is not None:
logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
hidden_states = attn(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
@@ -1117,19 +1104,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
**ckpt_kwargs,
)
else:
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = motion_module(
hidden_states,
num_frames=num_frames,
)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
return hidden_states
@@ -2178,7 +2157,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
emb = emb if aug_emb is None else emb + aug_emb
emb = emb.repeat_interleave(repeats=num_frames, dim=0)
encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
if "image_embeds" not in added_cond_kwargs:
+9 -4
View File
@@ -124,7 +124,12 @@ else:
"AnimateDiffSparseControlNetPipeline",
"AnimateDiffVideoToVideoPipeline",
]
_import_structure["flux"] = ["FluxPipeline", "FluxControlNetPipeline"]
_import_structure["flux"] = [
"FluxControlNetPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
"FluxPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -132,7 +137,7 @@ else:
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["cogvideo"] = ["CogVideoXPipeline"]
_import_structure["cogvideo"] = ["CogVideoXPipeline", "CogVideoXVideoToVideoPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
@@ -454,7 +459,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXPipeline
from .cogvideo import CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
@@ -494,7 +499,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VersatileDiffusionTextToImagePipeline,
VQDiffusionPipeline,
)
from .flux import FluxControlNetPipeline, FluxPipeline
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .i2vgen_xl import I2VGenXLPipeline
from .kandinsky import (
@@ -432,7 +432,6 @@ class AnimateDiffPipeline(
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -470,8 +469,8 @@ class AnimateDiffPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)=}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -557,11 +556,15 @@ class AnimateDiffPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt: Optional[Union[str, List[str]]] = None,
num_frames: Optional[int] = 16,
height: Optional[int] = None,
width: Optional[int] = None,
@@ -701,9 +704,10 @@ class AnimateDiffPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -716,22 +720,39 @@ class AnimateDiffPipeline(
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -783,6 +804,9 @@ class AnimateDiffPipeline(
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -505,8 +505,8 @@ class AnimateDiffControlNetPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -699,6 +699,10 @@ class AnimateDiffControlNetPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
@@ -858,9 +862,10 @@ class AnimateDiffControlNetPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -883,22 +888,39 @@ class AnimateDiffControlNetPipeline(
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -990,6 +1012,9 @@ class AnimateDiffControlNetPipeline(
# 8. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1002,7 +1027,6 @@ class AnimateDiffControlNetPipeline(
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -1143,6 +1143,8 @@ class AnimateDiffSDXLPipeline(
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_videos_per_prompt, 1)
@@ -878,6 +878,8 @@ class AnimateDiffSparseControlNetPipeline(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
# 4. Prepare IP-Adapter embeddings
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -246,7 +246,6 @@ class AnimateDiffVideoToVideoPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
self,
prompt,
@@ -299,7 +298,7 @@ class AnimateDiffVideoToVideoPipeline(
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -582,8 +581,8 @@ class AnimateDiffVideoToVideoPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and not isinstance(prompt, (str, list, dict)):
raise ValueError(f"`prompt` has to be of type `str`, `list` or `dict` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@@ -628,23 +627,20 @@ class AnimateDiffVideoToVideoPipeline(
def prepare_latents(
self,
video,
height,
width,
num_channels_latents,
batch_size,
timestep,
dtype,
device,
generator,
latents=None,
video: Optional[torch.Tensor] = None,
height: int = 64,
width: int = 64,
num_channels_latents: int = 4,
batch_size: int = 1,
timestep: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
decode_chunk_size: int = 16,
):
if latents is None:
num_frames = video.shape[1]
else:
num_frames = latents.shape[2]
add_noise: bool = False,
) -> torch.Tensor:
num_frames = video.shape[1] if latents is None else latents.shape[2]
shape = (
batch_size,
num_channels_latents,
@@ -708,8 +704,13 @@ class AnimateDiffVideoToVideoPipeline(
if shape != latents.shape:
# [B, C, F, H, W]
raise ValueError(f"`latents` expected to have {shape=}, but found {latents.shape=}")
latents = latents.to(device, dtype=dtype)
if add_noise:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.add_noise(latents, noise, timestep)
return latents
@property
@@ -735,6 +736,10 @@ class AnimateDiffVideoToVideoPipeline(
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
def __call__(
self,
@@ -743,6 +748,7 @@ class AnimateDiffVideoToVideoPipeline(
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
enforce_inference_steps: bool = False,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 7.5,
@@ -874,9 +880,10 @@ class AnimateDiffVideoToVideoPipeline(
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
if prompt is not None and isinstance(prompt, (str, dict)):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
@@ -884,29 +891,85 @@ class AnimateDiffVideoToVideoPipeline(
batch_size = prompt_embeds.shape[0]
device = self._execution_device
dtype = self.dtype
# 3. Encode input prompt
# 3. Prepare timesteps
if not enforce_inference_steps:
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
else:
denoising_inference_steps = int(num_inference_steps / strength)
timesteps, denoising_inference_steps = retrieve_timesteps(
self.scheduler, denoising_inference_steps, device, timesteps, sigmas
)
timesteps = timesteps[-num_inference_steps:]
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
# 4. Prepare latent variables
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
# Move the number of frames before the number of channels.
video = video.permute(0, 2, 1, 3, 4)
video = video.to(device=device, dtype=dtype)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
video=video,
height=height,
width=width,
num_channels_latents=num_channels_latents,
batch_size=batch_size * num_videos_per_prompt,
timestep=latent_timestep,
dtype=dtype,
device=device,
generator=generator,
latents=latents,
decode_chunk_size=decode_chunk_size,
add_noise=enforce_inference_steps,
)
# 5. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
num_frames = latents.shape[2]
if self.free_noise_enabled:
prompt_embeds, negative_prompt_embeds = self._encode_prompt_free_noise(
prompt=prompt,
num_frames=num_frames,
device=device,
num_videos_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
else:
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
# 6. Prepare IP-Adapter embeddings
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
@@ -916,38 +979,10 @@ class AnimateDiffVideoToVideoPipeline(
self.do_classifier_free_guidance,
)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
# 5. Prepare latent variables
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
# Move the number of frames before the number of channels.
video = video.permute(0, 2, 1, 3, 4)
video = video.to(device=device, dtype=prompt_embeds.dtype)
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
video=video,
height=height,
width=width,
num_channels_latents=num_channels_latents,
batch_size=batch_size * num_videos_per_prompt,
timestep=latent_timestep,
dtype=prompt_embeds.dtype,
device=device,
generator=generator,
latents=latents,
decode_chunk_size=decode_chunk_size,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Add image embeds for IP-Adapter
# 8. Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
@@ -967,9 +1002,12 @@ class AnimateDiffVideoToVideoPipeline(
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# 8. Denoising loop
# 9. Denoising loop
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -1005,14 +1043,14 @@ class AnimateDiffVideoToVideoPipeline(
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# 9. Post-processing
# 10. Post-processing
if output_type == "latent":
video = latents
else:
video_tensor = self.decode_latents(latents, decode_chunk_size)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)
# 10. Offload all models
# 11. Offload all models
self.maybe_free_model_hooks()
if not return_dict:
+18 -7
View File
@@ -29,7 +29,7 @@ from .controlnet import (
StableDiffusionXLControlNetPipeline,
)
from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
from .flux import FluxPipeline
from .flux import FluxControlNetPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -108,6 +108,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("pixart-sigma-pag", PixArtSigmaPAGPipeline),
("auraflow", AuraFlowPipeline),
("flux", FluxPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
]
)
@@ -126,6 +127,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),
("flux", FluxImg2ImgPipeline),
]
)
@@ -140,6 +142,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
("flux", FluxInpaintPipeline),
]
)
@@ -660,12 +663,17 @@ class AutoPipelineForImage2Image(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
# the `orig_class_name` can be:
# `- *Pipeline` (for regular text-to-image checkpoint)
# `- *Img2ImgPipeline` (for refiner checkpoint)
to_replace = "Img2ImgPipeline" if "Img2Img" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
orig_class_name = orig_class_name.replace("Pipeline", "PAGPipeline")
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
@@ -952,14 +960,17 @@ class AutoPipelineForInpainting(ConfigMixin):
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
orig_class_name = config["_class_name"]
# The `orig_class_name`` can be:
# `- *InpaintPipeline` (for inpaint-specific checkpoint)
# - or *Pipeline (for regular text-to-image checkpoint)
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
if "controlnet" in kwargs:
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
if "enable_pag" in kwargs:
enable_pag = kwargs.pop("enable_pag")
if enable_pag:
to_replace = "InpaintPipeline" if "Inpaint" in config["_class_name"] else "Pipeline"
orig_class_name = config["_class_name"].replace(to_replace, "PAG" + to_replace)
orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
kwargs = {**load_config_kwargs, **kwargs}
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogvideox"] = ["CogVideoXPipeline"]
_import_structure["pipeline_cogvideox_video2video"] = ["CogVideoXVideoToVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_cogvideox import CogVideoXPipeline
from .pipeline_cogvideox_video2video import CogVideoXVideoToVideoPipeline
else:
import sys
@@ -15,7 +15,6 @@
import inspect
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -26,9 +25,10 @@ from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ...utils import logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -136,21 +136,6 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
@dataclass
class CogVideoXPipelineOutput(BaseOutput):
r"""
Output class for CogVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
class CogVideoXPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using CogVideoX.
@@ -463,7 +448,6 @@ class CogVideoXPipeline(DiffusionPipeline):
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
use_real=True,
)
freqs_cos = freqs_cos.to(device=device)
@@ -0,0 +1,812 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import math
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from PIL import Image
from transformers import T5EncoderModel, T5Tokenizer
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
from ...models.embeddings import get_3d_rotary_pos_embed
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from .pipeline_output import CogVideoXPipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import CogVideoXDPMScheduler, CogVideoXVideoToVideoPipeline
>>> from diffusers.utils import export_to_video, load_video
>>> # Models: "THUDM/CogVideoX-2b" or "THUDM/CogVideoX-5b"
>>> pipe = CogVideoXVideoToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
>>> input_video = load_video(
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
... )
>>> prompt = (
... "An astronaut stands triumphantly at the peak of a towering mountain. Panorama of rugged peaks and "
... "valleys. Very futuristic vibe and animated aesthetic. Highlights of purple and golden colors in "
... "the scene. The sky is looks like an animated/cartoonish dream of galaxies, nebulae, stars, planets, "
... "moons, but the remainder of the scene is mostly realistic."
... )
>>> video = pipe(
... video=input_video, prompt=prompt, strength=0.8, guidance_scale=6, num_inference_steps=50
... ).frames[0]
>>> export_to_video(video, "output.mp4", fps=8)
```
"""
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class CogVideoXVideoToVideoPipeline(DiffusionPipeline):
r"""
Pipeline for video-to-video generation using CogVideoX.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. CogVideoX uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`CogVideoXTransformer3DModel`]):
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
"""
_optional_components = []
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
]
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKLCogVideoX,
transformer: CogVideoXTransformer3DModel,
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_spatial = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.vae_scale_factor_temporal = (
self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(
self,
video: Optional[torch.Tensor] = None,
batch_size: int = 1,
num_channels_latents: int = 16,
height: int = 60,
width: int = 90,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
timestep: Optional[torch.Tensor] = None,
):
num_frames = (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
shape = (
batch_size,
num_frames,
num_channels_latents,
height // self.vae_scale_factor_spatial,
width // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
init_latents = self.vae.config.scaling_factor * init_latents
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.add_noise(init_latents, noise, timestep)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
latents = 1 / self.vae.config.scaling_factor * latents
frames = self.vae.decode(latents).sample
return frames
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
strength,
negative_prompt,
callback_on_step_end_tensor_inputs,
video=None,
latents=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` should be provided")
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
def fuse_qkv_projections(self) -> None:
r"""Enables fused QKV projections."""
self.fusing_transformer = True
self.transformer.fuse_qkv_projections()
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
def unfuse_qkv_projections(self) -> None:
r"""Disable QKV projection fusion if enabled."""
if not self.fusing_transformer:
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
else:
self.transformer.unfuse_qkv_projections()
self.fusing_transformer = False
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
def _prepare_rotary_positional_embeddings(
self,
height: int,
width: int,
num_frames: int,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
grid_crops_coords = get_resize_crop_region_for_grid(
(grid_height, grid_width), base_size_width, base_size_height
)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
)
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin
@property
def guidance_scale(self):
return self._guidance_scale
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
video: List[Image.Image] = None,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: int = 480,
width: int = 720,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
strength: float = 0.8,
guidance_scale: float = 6,
use_dynamic_cfg: bool = False,
num_videos_per_prompt: int = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 226,
) -> Union[CogVideoXPipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
video (`List[PIL.Image.Image]`):
The input video to condition the generation on. Must be a list of images/frames of the video.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
strength (`float`, *optional*, defaults to 0.8):
Higher strength leads to more differences between original video and generated video.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `226`):
Maximum sequence length in encoded prompt. Must be consistent with
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
Examples:
Returns:
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.transformer.config.sample_size * self.vae_scale_factor_spatial
width = width or self.transformer.config.sample_size * self.vae_scale_factor_spatial
num_videos_per_prompt = 1
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
strength,
negative_prompt,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Default call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
negative_prompt,
do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
self._num_timesteps = len(timesteps)
# 5. Prepare latents
if latents is None:
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=prompt_embeds.dtype)
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
video,
batch_size * num_videos_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
latent_timestep,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Create rotary embeds if required
image_rotary_emb = (
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
if self.transformer.config.use_rotary_positional_embeddings
else None
)
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
# for DPM-solver++
old_pred_original_sample = None
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
else:
latents, old_pred_original_sample = self.scheduler.step(
noise_pred,
old_pred_original_sample,
t,
timesteps[i - 1] if i > 0 else None,
latents,
**extra_step_kwargs,
return_dict=False,
)
latents = latents.to(prompt_embeds.dtype)
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return CogVideoXPipelineOutput(frames=video)
@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class CogVideoXPipelineOutput(BaseOutput):
r"""
Output class for CogVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor
@@ -546,7 +546,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
)
else:
self.encoder_hid_proj = None
+4
View File
@@ -24,6 +24,8 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_flux"] = ["FluxPipeline"]
_import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -33,6 +35,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_flux import FluxPipeline
from .pipeline_flux_controlnet import FluxControlNetPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
else:
import sys
@@ -280,7 +280,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
@@ -536,7 +536,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
@@ -13,7 +13,7 @@
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -27,7 +27,7 @@ from transformers import (
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnet_flux import FluxControlNetModel
from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
@@ -61,7 +61,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import FluxControlNetPipeline
>>> from diffusers import FluxControlNetModel
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny-alpha"
>>> controlnet_model = "InstantX/FLUX.1-dev-controlnet-canny"
>>> controlnet = FluxControlNetModel.from_pretrained(controlnet_model, torch_dtype=torch.bfloat16)
>>> pipe = FluxControlNetPipeline.from_pretrained(
... base_model, controlnet=controlnet, torch_dtype=torch.bfloat16
@@ -195,7 +195,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
controlnet: FluxControlNetModel,
controlnet: Union[
FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
],
):
super().__init__()
@@ -300,7 +302,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
@@ -571,6 +573,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
timesteps: List[int] = None,
guidance_scale: float = 7.0,
control_image: PipelineImageInput = None,
control_mode: Optional[Union[int, List[int]]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -611,6 +614,20 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
images must be passed as a list such that each element of the list can be correctly batched for input
to a single ControlNet.
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
the corresponding scale as a list.
control_mode (`int` or `List[int]`,, *optional*, defaults to None):
The control mode when applying ControlNet-Union.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -730,6 +747,55 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
width_control_image,
)
# set control mode
if control_mode is not None:
control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
elif isinstance(self.controlnet, FluxMultiControlNetModel):
control_images = []
for control_image_ in control_image:
control_image_ = self.prepare_image(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=dtype,
)
height, width = control_image_.shape[-2:]
# vae encode
control_image_ = self.vae.encode(control_image_).latent_dist.sample()
control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# pack
height_control_image, width_control_image = control_image_.shape[2:]
control_image_ = self._pack_latents(
control_image_,
batch_size * num_images_per_prompt,
num_channels_latents,
height_control_image,
width_control_image,
)
control_images.append(control_image_)
control_image = control_images
# set control mode
control_mode_ = []
if isinstance(control_mode, list):
for cmode in control_mode:
if cmode is None:
control_mode_.append(-1)
else:
control_mode_.append(cmode)
control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
control_mode = control_mode.reshape([-1, 1])
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
@@ -785,6 +851,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
hidden_states=latents,
controlnet_cond=control_image,
controlnet_mode=control_mode,
conditioning_scale=controlnet_conditioning_scale,
timestep=timestep / 1000,
guidance=guidance,
@@ -0,0 +1,844 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import FluxPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import FluxImg2ImgPipeline
>>> from diffusers.utils import load_image
>>> device = "cuda"
>>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
>>> pipe = pipe.to(device)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> init_image = load_image(url).resize((1024, 1024))
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
>>> images = pipe(
... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0
... ).images[0]
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for image inpainting.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
return image_latents
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
def check_inputs(
self,
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Preprocess image
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4.Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.base_image_seq_len,
self.scheduler.config.max_image_seq_len,
self.scheduler.config.base_shift,
self.scheduler.config.max_shift,
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
File diff suppressed because it is too large Load Diff
+365 -5
View File
@@ -12,16 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ..models.transformers.transformer_2d import Transformer2DModel
from ..models.unets.unet_motion_model import (
AnimateDiffTransformer3D,
CrossAttnDownBlockMotion,
DownBlockMotion,
UpBlockMotion,
)
from ..pipelines.pipeline_utils import DiffusionPipeline
from ..utils import logging
from ..utils.torch_utils import randn_tensor
@@ -29,6 +34,114 @@ from ..utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SplitInferenceModule(nn.Module):
r"""
A wrapper module class that splits inputs along a specified dimension before performing a forward pass.
This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
them into smaller chunks, processing each chunk separately, and then reassembling the results.
Args:
module (`nn.Module`):
The underlying PyTorch module that will be applied to each chunk of split inputs.
split_size (`int`, defaults to `1`):
The size of each chunk after splitting the input tensor.
split_dim (`int`, defaults to `0`):
The dimension along which the input tensors are split.
input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`):
A list of keyword arguments (strings) that represent the input tensors to be split.
Workflow:
1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
`torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
that were passed.
3. The output tensors from each split are concatenated back together along `split_dim` before returning.
Example:
```python
>>> import torch
>>> import torch.nn as nn
>>> model = nn.Linear(1000, 1000)
>>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"])
>>> input_tensor = torch.randn(42, 1000)
>>> # Will split the tensor into 21 slices of shape [2, 1000].
>>> output = split_module(input=input_tensor)
```
It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex
multi-dimensional splitting.
"""
def __init__(
self,
module: nn.Module,
split_size: int = 1,
split_dim: int = 0,
input_kwargs_to_split: List[str] = ["hidden_states"],
) -> None:
super().__init__()
self.module = module
self.split_size = split_size
self.split_dim = split_dim
self.input_kwargs_to_split = set(input_kwargs_to_split)
def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
r"""Forward method for the `SplitInferenceModule`.
This method processes the input by splitting specified keyword arguments along a given dimension, running the
underlying module on each split, and then concatenating the results. The splitting is controlled by the
`split_size` and `split_dim` parameters specified during initialization.
Args:
*args (`Any`):
Positional arguments that are passed directly to the `module` without modification.
**kwargs (`Dict[str, torch.Tensor]`):
Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
arguments are passed unchanged.
Returns:
`Union[torch.Tensor, Tuple[torch.Tensor]]`:
The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
without it.
- If the underlying module returns a single tensor, the result will be a single concatenated tensor
along the same `split_dim` after processing all splits.
- If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
"""
split_inputs = {}
# 1. Split inputs that were specified during initialization and also present in passed kwargs
for key in list(kwargs.keys()):
if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
continue
split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
kwargs.pop(key)
# 2. Invoke forward pass across each split
results = []
for split_input in zip(*split_inputs.values()):
inputs = dict(zip(split_inputs.keys(), split_input))
inputs.update(kwargs)
intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
results.append(intermediate_tensor_or_tensor_tuple)
# 3. Concatenate split restuls to obtain final outputs
if isinstance(results[0], torch.Tensor):
return torch.cat(results, dim=self.split_dim)
elif isinstance(results[0], tuple):
return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
else:
raise ValueError(
"In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
)
class AnimateDiffFreeNoiseMixin:
r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
@@ -69,6 +182,9 @@ class AnimateDiffFreeNoiseMixin:
motion_module.transformer_blocks[i].load_state_dict(
basic_transfomer_block.state_dict(), strict=True
)
motion_module.transformer_blocks[i].set_chunk_feed_forward(
basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
)
def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
r"""Helper function to disable FreeNoise in transformer blocks."""
@@ -97,6 +213,145 @@ class AnimateDiffFreeNoiseMixin:
motion_module.transformer_blocks[i].load_state_dict(
free_noise_transfomer_block.state_dict(), strict=True
)
motion_module.transformer_blocks[i].set_chunk_feed_forward(
free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
)
def _check_inputs_free_noise(
self,
prompt,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
num_frames,
) -> None:
if not isinstance(prompt, (str, dict)):
raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
if negative_prompt is not None:
if not isinstance(negative_prompt, (str, dict)):
raise ValueError(
f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
)
if prompt_embeds is not None or negative_prompt_embeds is not None:
raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.")
frame_indices = [isinstance(x, int) for x in prompt.keys()]
frame_prompts = [isinstance(x, str) for x in prompt.values()]
min_frame = min(list(prompt.keys()))
max_frame = max(list(prompt.keys()))
if not all(frame_indices):
raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
if not all(frame_prompts):
raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
if min_frame != 0:
raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
if max_frame >= num_frames:
raise ValueError(
f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing."
)
def _encode_prompt_free_noise(
self,
prompt: Union[str, Dict[int, str]],
num_frames: int,
device: torch.device,
num_videos_per_prompt: int,
do_classifier_free_guidance: bool,
negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
) -> torch.Tensor:
if negative_prompt is None:
negative_prompt = ""
# Ensure that we have a dictionary of prompts
if isinstance(prompt, str):
prompt = {0: prompt}
if isinstance(negative_prompt, str):
negative_prompt = {0: negative_prompt}
self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
# Sort the prompts based on frame indices
prompt = dict(sorted(prompt.items()))
negative_prompt = dict(sorted(negative_prompt.items()))
# Ensure that we have a prompt for the last frame index
prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
frame_indices = list(prompt.keys())
frame_prompts = list(prompt.values())
frame_negative_indices = list(negative_prompt.keys())
frame_negative_prompts = list(negative_prompt.values())
# Generate and interpolate positive prompts
prompt_embeds, _ = self.encode_prompt(
prompt=frame_prompts,
device=device,
num_images_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=False,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
shape = (num_frames, *prompt_embeds.shape[1:])
prompt_interpolation_embeds = prompt_embeds.new_zeros(shape)
for i in range(len(frame_indices) - 1):
start_frame = frame_indices[i]
end_frame = frame_indices[i + 1]
start_tensor = prompt_embeds[i].unsqueeze(0)
end_tensor = prompt_embeds[i + 1].unsqueeze(0)
prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(
start_frame, end_frame, start_tensor, end_tensor
)
# Generate and interpolate negative prompts
negative_prompt_embeds = None
negative_prompt_interpolation_embeds = None
if do_classifier_free_guidance:
_, negative_prompt_embeds = self.encode_prompt(
prompt=[""] * len(frame_negative_prompts),
device=device,
num_images_per_prompt=num_videos_per_prompt,
do_classifier_free_guidance=True,
negative_prompt=frame_negative_prompts,
prompt_embeds=None,
negative_prompt_embeds=None,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape)
for i in range(len(frame_negative_indices) - 1):
start_frame = frame_negative_indices[i]
end_frame = frame_negative_indices[i + 1]
start_tensor = negative_prompt_embeds[i].unsqueeze(0)
end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
negative_prompt_interpolation_embeds[
start_frame : end_frame + 1
] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
prompt_embeds = prompt_interpolation_embeds
negative_prompt_embeds = negative_prompt_interpolation_embeds
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
return prompt_embeds, negative_prompt_embeds
def _prepare_latents_free_noise(
self,
@@ -172,12 +427,29 @@ class AnimateDiffFreeNoiseMixin:
latents = latents[:, :, :num_frames]
return latents
def _lerp(
self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
) -> torch.Tensor:
num_indices = end_index - start_index + 1
interpolated_tensors = []
for i in range(num_indices):
alpha = i / (num_indices - 1)
interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
interpolated_tensors.append(interpolated_tensor)
interpolated_tensors = torch.cat(interpolated_tensors)
return interpolated_tensors
def enable_free_noise(
self,
context_length: Optional[int] = 16,
context_stride: int = 4,
weighting_scheme: str = "pyramid",
noise_type: str = "shuffle_context",
prompt_interpolation_callback: Optional[
Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
] = None,
) -> None:
r"""
Enable long video generation using FreeNoise.
@@ -195,13 +467,27 @@ class AnimateDiffFreeNoiseMixin:
weighting_scheme (`str`, defaults to `pyramid`):
Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
schemes are supported currently:
- "flat"
Performs weighting averaging with a flat weight pattern: [1, 1, 1, 1, 1].
- "pyramid"
Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
- "delayed_reverse_sawtooth"
Performs weighted averaging with low weights for earlier frames and high-to-low weights for
later frames: [0.01, 0.01, 3, 2, 1].
noise_type (`str`, defaults to "shuffle_context"):
TODO
Must be one of ["shuffle_context", "repeat_context", "random"].
- "shuffle_context"
Shuffles a fixed batch of `context_length` latents to create a final latent of size
`num_frames`. This is usually the best setting for most generation scenarious. However, there
might be visible repetition noticeable in the kinds of motion/animation generated.
- "repeated_context"
Repeats a fixed batch of `context_length` latents to create a final latent of size
`num_frames`.
- "random"
The final latents are random without any repetition.
"""
allowed_weighting_scheme = ["pyramid"]
allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
if context_length > self.motion_adapter.config.motion_max_seq_length:
@@ -219,18 +505,92 @@ class AnimateDiffFreeNoiseMixin:
self._free_noise_context_stride = context_stride
self._free_noise_weighting_scheme = weighting_scheme
self._free_noise_noise_type = noise_type
self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp
if hasattr(self.unet.mid_block, "motion_modules"):
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
else:
blocks = [*self.unet.down_blocks, *self.unet.up_blocks]
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._enable_free_noise_in_block(block)
def disable_free_noise(self) -> None:
r"""Disable the FreeNoise sampling mechanism."""
self._free_noise_context_length = None
if hasattr(self.unet.mid_block, "motion_modules"):
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
else:
blocks = [*self.unet.down_blocks, *self.unet.up_blocks]
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
self._disable_free_noise_in_block(block)
def _enable_split_inference_motion_modules_(
self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
) -> None:
for motion_module in motion_modules:
motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
for i in range(len(motion_module.transformer_blocks)):
motion_module.transformer_blocks[i] = SplitInferenceModule(
motion_module.transformer_blocks[i],
spatial_split_size,
0,
["hidden_states", "encoder_hidden_states"],
)
motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
def _enable_split_inference_attentions_(
self, attentions: List[Transformer2DModel], temporal_split_size: int
) -> None:
for i in range(len(attentions)):
attentions[i] = SplitInferenceModule(
attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
)
def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
for i in range(len(resnets)):
resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])
def _enable_split_inference_samplers_(
self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
) -> None:
for i in range(len(samplers)):
samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
r"""
Enable FreeNoise memory optimizations by utilizing
[`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks.
Args:
spatial_split_size (`int`, defaults to `256`):
The split size across spatial dimensions for internal blocks. This is used in facilitating split
inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion
modeling blocks.
temporal_split_size (`int`, defaults to `16`):
The split size across temporal dimensions for internal blocks. This is used in facilitating split
inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial
attention, resnets, downsampling and upsampling blocks.
"""
# TODO(aryan): Discuss on what's the best way to provide more control to users
blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
for block in blocks:
if getattr(block, "motion_modules", None) is not None:
self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
if getattr(block, "attentions", None) is not None:
self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
if getattr(block, "resnets", None) is not None:
self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
if getattr(block, "downsamplers", None) is not None:
self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
if getattr(block, "upsamplers", None) is not None:
self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)
@property
def free_noise_enabled(self):
return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
@@ -547,7 +547,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
negative_image_embeds = prior_outputs[1]
prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
image = [image] if isinstance(prompt, PIL.Image.Image) else image
image = [image] if isinstance(image, PIL.Image.Image) else image
if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
prompt = (image_embeds.shape[0] // len(prompt)) * prompt
@@ -813,7 +813,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
negative_image_embeds = prior_outputs[1]
prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
image = [image] if isinstance(prompt, PIL.Image.Image) else image
image = [image] if isinstance(image, PIL.Image.Image) else image
mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image
if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
@@ -734,6 +734,8 @@ class AnimateDiffPAGPipeline(
elif self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
@@ -805,7 +807,9 @@ class AnimateDiffPAGPipeline(
with self.progress_bar(total=self._num_timesteps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
latent_model_input = torch.cat(
[latents] * (prompt_embeds.shape[0] // num_frames // latents.shape[0])
)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
@@ -824,6 +824,8 @@ class PIAPipeline(
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_embeds = prompt_embeds.repeat_interleave(repeats=num_frames, dim=0)
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
@@ -1744,6 +1744,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
for name, component in pipeline.components.items():
if name in expected_modules and name not in passed_class_obj:
# for model components, we will not switch over if the class does not matches the type hint in the new pipeline's signature
component = component._orig_mod if is_compiled_module(component) else component
if (
not isinstance(component, ModelMixin)
or type(component) in component_types[name]
@@ -25,7 +25,7 @@ from transformers import (
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -148,7 +148,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class StableDiffusion3InpaintPipeline(DiffusionPipeline):
class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
+20 -3
View File
@@ -1,5 +1,6 @@
import contextlib
import copy
import gc
import math
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting
def clear_objs_and_retain_memory(objs: List[Any]):
"""Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
if len(objs) >= 1:
for obj in objs:
del obj
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
elif is_torch_npu_available():
torch_npu.empty_cache()
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
@@ -418,11 +435,11 @@ class EMAModel:
one_minus_decay = 1 - decay
context_manager = contextlib.nullcontext
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
import deepspeed
if self.foreach:
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
with context_manager():
@@ -444,7 +461,7 @@ class EMAModel:
else:
for s_param, param in zip(self.shadow_params, parameters):
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
with context_manager():
+15
View File
@@ -197,6 +197,21 @@ class FluxControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FluxMultiControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class FluxTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -272,6 +272,21 @@ class CogVideoXPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class CogVideoXVideoToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -302,6 +317,36 @@ class FluxControlNetPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class FluxImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
+2 -1
View File
@@ -157,11 +157,12 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
@slow
@require_torch_gpu
def test_integration_move_lora_dora_cpu(self):
from peft import LoraConfig
path = "runwayml/stable-diffusion-v1-5"
path = "Lykon/dreamshaper-8"
unet_lora_config = LoraConfig(
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+8 -4
View File
@@ -528,6 +528,10 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
def test_forward_with_norm_groups(self):
pass
@unittest.skip("No attention module used in this model")
def test_set_attn_processor_for_determinism(self):
return
@slow
class AutoencoderTinyIntegrationTests(unittest.TestCase):
@@ -1032,9 +1036,9 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/img2img/sketch-mountains-input.jpg"
).resize((256, 256))
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[
None, :, :, :
].cuda()
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :].to(
torch_device
)
latent = vae.encode(image).latent_dist.mean
@@ -1075,7 +1079,7 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
image = (
torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :]
.half()
.cuda()
.to(torch_device)
)
latent = vae.encode(image).latent_dist.mean
-11
View File
@@ -55,17 +55,6 @@ class EmbeddingsTests(unittest.TestCase):
assert grad > prev_grad
prev_grad = grad
def test_timestep_defaults(self):
embedding_dim = 16
timesteps = torch.arange(10)
t1 = get_timestep_embedding(timesteps, embedding_dim)
t2 = get_timestep_embedding(
timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, max_period=10_000
)
assert torch.allclose(t1.cpu(), t2.cpu(), 1e-3)
def test_timestep_flip_sin_cos(self):
embedding_dim = 16
timesteps = torch.arange(10)
+4 -11
View File
@@ -183,17 +183,6 @@ class ModelUtilsTest(unittest.TestCase):
class UNetTesterMixin:
def test_forward_signature(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["sample", "timestep"]
self.assertListEqual(arg_names[:2], expected_arg_names)
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -220,6 +209,7 @@ class ModelTesterMixin:
base_precision = 1e-3
forward_requires_fresh_args = False
model_split_percents = [0.5, 0.7, 0.9]
uses_custom_attn_processor = False
def check_device_map_is_respected(self, model, device_map):
for param_name, param in model.named_parameters():
@@ -417,6 +407,9 @@ class ModelTesterMixin:
@require_torch_gpu
def test_set_attn_processor_for_determinism(self):
if self.uses_custom_attn_processor:
return
torch.use_deterministic_algorithms(False)
if self.forward_requires_fresh_args:
model = self.model_class(**self.init_dict)
@@ -32,6 +32,7 @@ enable_full_determinism()
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
@@ -32,6 +32,9 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
@@ -32,6 +32,7 @@ enable_full_determinism()
class LuminaNextDiT2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LuminaNextDiT2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
@@ -51,7 +51,7 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size * num_frames, 4, 16)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@@ -175,7 +175,7 @@ class AnimateDiffPipelineFastTests(
def test_attention_slicing_forward_pass(self):
pass
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array(
@@ -209,7 +209,7 @@ class AnimateDiffPipelineFastTests(
0.5620,
]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_dict_tuple_outputs_equivalent(self):
expected_slice = None
@@ -460,6 +460,53 @@ class AnimateDiffPipelineFastTests(
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)
def test_free_noise_split_inference(self):
components = self.get_dummy_components()
pipe: AnimateDiffPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_noise(8, 4)
inputs_normal = self.get_dummy_inputs(torch_device)
frames_normal = pipe(**inputs_normal).frames[0]
# Test FreeNoise with split inference memory-optimization
pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4)
inputs_enable_split_inference = self.get_dummy_inputs(torch_device)
frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0]
sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum()
self.assertLess(
sum_split_inference,
1e-4,
"Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results",
)
def test_free_noise_multi_prompt(self):
components = self.get_dummy_components()
pipe: AnimateDiffPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
context_length = 8
context_stride = 4
pipe.enable_free_noise(context_length, context_stride)
# Make sure that pipeline works when prompt indices are within num_frames bounds
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
inputs["num_frames"] = 16
pipe(**inputs).frames[0]
with self.assertRaises(ValueError):
# Ensure that prompt indices are within bounds
inputs = self.get_dummy_inputs(torch_device)
inputs["num_frames"] = 16
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
@@ -193,7 +193,7 @@ class AnimateDiffControlNetPipelineFastTests(
def test_attention_slicing_forward_pass(self):
pass
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array(
@@ -218,7 +218,7 @@ class AnimateDiffControlNetPipelineFastTests(
0.5155,
]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_dict_tuple_outputs_equivalent(self):
expected_slice = None
@@ -476,6 +476,27 @@ class AnimateDiffControlNetPipelineFastTests(
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)
def test_free_noise_multi_prompt(self):
components = self.get_dummy_components()
pipe: AnimateDiffControlNetPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
context_length = 8
context_stride = 4
pipe.enable_free_noise(context_length, context_stride)
# Make sure that pipeline works when prompt indices are within num_frames bounds
inputs = self.get_dummy_inputs(torch_device, num_frames=16)
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
pipe(**inputs).frames[0]
with self.assertRaises(ValueError):
# Ensure that prompt indices are within bounds
inputs = self.get_dummy_inputs(torch_device, num_frames=16)
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]
def test_vae_slicing(self, video_count=2):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -195,7 +195,7 @@ class AnimateDiffSparseControlNetPipelineFastTests(
def test_attention_slicing_forward_pass(self):
pass
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array(
@@ -220,7 +220,7 @@ class AnimateDiffSparseControlNetPipelineFastTests(
0.5155,
]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_dict_tuple_outputs_equivalent(self):
expected_slice = None
@@ -175,7 +175,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(
def test_attention_slicing_forward_pass(self):
pass
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
@@ -201,7 +201,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(
0.5378,
]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_inference_batch_single_identical(
self,
@@ -491,3 +491,56 @@ class AnimateDiffVideoToVideoPipelineFastTests(
1e-4,
"Disabling of FreeNoise should lead to results similar to the default pipeline results",
)
def test_free_noise_split_inference(self):
components = self.get_dummy_components()
pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
pipe.enable_free_noise(8, 4)
inputs_normal = self.get_dummy_inputs(torch_device, num_frames=16)
inputs_normal["num_inference_steps"] = 2
inputs_normal["strength"] = 0.5
frames_normal = pipe(**inputs_normal).frames[0]
# Test FreeNoise with split inference memory-optimization
pipe.enable_free_noise_split_inference(spatial_split_size=16, temporal_split_size=4)
inputs_enable_split_inference = self.get_dummy_inputs(torch_device, num_frames=16)
inputs_enable_split_inference["num_inference_steps"] = 2
inputs_enable_split_inference["strength"] = 0.5
frames_enable_split_inference = pipe(**inputs_enable_split_inference).frames[0]
sum_split_inference = np.abs(to_np(frames_normal) - to_np(frames_enable_split_inference)).sum()
self.assertLess(
sum_split_inference,
1e-4,
"Enabling FreeNoise Split Inference memory-optimizations should lead to results similar to the default pipeline results",
)
def test_free_noise_multi_prompt(self):
components = self.get_dummy_components()
pipe: AnimateDiffVideoToVideoPipeline = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
context_length = 8
context_stride = 4
pipe.enable_free_noise(context_length, context_stride)
# Make sure that pipeline works when prompt indices are within num_frames bounds
inputs = self.get_dummy_inputs(torch_device, num_frames=16)
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf"}
inputs["num_inference_steps"] = 2
inputs["strength"] = 0.5
pipe(**inputs).frames[0]
with self.assertRaises(ValueError):
# Ensure that prompt indices are within bounds
inputs = self.get_dummy_inputs(torch_device, num_frames=16)
inputs["num_inference_steps"] = 2
inputs["strength"] = 0.5
inputs["prompt"] = {0: "Caterpillar on a leaf", 10: "Butterfly on a leaf", 42: "Error on a leaf"}
pipe(**inputs).frames[0]
@@ -0,0 +1,328 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
to_np,
)
enable_full_determinism()
class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = CogVideoXVideoToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"video"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
def get_dummy_components(self):
torch.manual_seed(0)
transformer = CogVideoXTransformer3DModel(
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
# But, since we are using tiny-random-t5 here, we need the internal dim of CogVideoXTransformer3DModel
# to be 32. The internal dim is product of num_attention_heads and attention_head_dim
num_attention_heads=4,
attention_head_dim=8,
in_channels=4,
out_channels=4,
time_embed_dim=2,
text_embed_dim=32, # Must match with tiny-random-t5
num_layers=1,
sample_width=16, # latent width: 2 -> final width: 16
sample_height=16, # latent height: 2 -> final height: 16
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
patch_size=2,
temporal_compression_ratio=4,
max_text_seq_length=16,
)
torch.manual_seed(0)
vae = AutoencoderKLCogVideoX(
in_channels=3,
out_channels=3,
down_block_types=(
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
"CogVideoXDownBlock3D",
),
up_block_types=(
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
"CogVideoXUpBlock3D",
),
block_out_channels=(8, 8, 8, 8),
latent_channels=4,
layers_per_block=1,
norm_num_groups=2,
temporal_compression_ratio=4,
)
torch.manual_seed(0)
scheduler = DDIMScheduler()
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed: int = 0, num_frames: int = 8):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
video_height = 16
video_width = 16
video = [Image.new("RGB", (video_width, video_height))] * num_frames
inputs = {
"video": video,
"prompt": "dance monkey",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"strength": 0.5,
"guidance_scale": 6.0,
# Cannot reduce because convolution kernel becomes bigger than sample
"height": video_height,
"width": video_width,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
self.assertEqual(generated_video.shape, (8, 3, 16, 16))
expected_video = torch.randn(8, 3, 16, 16)
max_diff = np.abs(generated_video - expected_video).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
has_callback_step_end = "callback_on_step_end" in sig.parameters
if not (has_callback_tensor_inputs and has_callback_step_end):
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
# iterate over callback args
for tensor_name, tensor_value in callback_kwargs.items():
# check that we're only passing in allowed tensor inputs
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs(torch_device)
# Test passing in a subset
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
output = pipe(**inputs)[0]
# Test passing in a everything
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
is_last = i == (pipe.num_timesteps - 1)
if is_last:
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
return callback_kwargs
inputs["callback_on_step_end"] = callback_inputs_change_tensor
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
output = pipe(**inputs)[0]
assert output.abs().sum() < 1e10
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
# Since VideoToVideo uses both encoder and decoder tiling, there seems to be much more numerical
# difference. We seem to need a higher tolerance here...
# TODO(aryan): Look into this more deeply
expected_diff_max = 0.4
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_overlap_factor_height=1 / 12,
tile_overlap_factor_width=1 / 12,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
@unittest.skip("xformers attention processor does not exist for CogVideoX")
def test_xformers_attention_forwardGenerator_pass(self):
pass
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames # [B, F, C, H, W]
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.fuse_qkv_projections()
assert check_qkv_fusion_processors_exist(
pipe.transformer
), "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
assert check_qkv_fusion_matches_attn_procs_length(
pipe.transformer, pipe.transformer.original_attn_processors
), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_fused = frames[0, -2:, -1, -3:, -3:]
pipe.transformer.unfuse_qkv_projections()
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_disabled = frames[0, -2:, -1, -3:, -3:]
assert np.allclose(
original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
), "Fusion of QKV projections shouldn't affect the outputs."
assert np.allclose(
image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
@@ -220,11 +220,11 @@ class ControlNetPipelineFastTests(
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.5234, 0.3333, 0.1745, 0.7605, 0.6224, 0.4637, 0.6989, 0.7526, 0.4665])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
@@ -460,11 +460,11 @@ class StableDiffusionMultiControlNetPipelineFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.2422, 0.3425, 0.4048, 0.5351, 0.3503, 0.2419, 0.4645, 0.4570, 0.3804])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_save_pretrained_raise_not_implemented_exception(self):
components = self.get_dummy_components()
@@ -679,11 +679,11 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.5264, 0.3203, 0.1602, 0.8235, 0.6332, 0.4593, 0.7226, 0.7777, 0.4780])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_save_pretrained_raise_not_implemented_exception(self):
components = self.get_dummy_components()
@@ -173,11 +173,11 @@ class ControlNetImg2ImgPipelineFastTests(
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.7096, 0.5149, 0.3571, 0.5897, 0.4715, 0.4052, 0.6098, 0.6886, 0.4213])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
@@ -371,11 +371,11 @@ class StableDiffusionMultiControlNetPipelineFastTests(
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=2e-3)
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.5293, 0.7339, 0.6642, 0.3950, 0.5212, 0.5175, 0.7002, 0.5907, 0.5182])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_save_pretrained_raise_not_implemented_exception(self):
components = self.get_dummy_components()
@@ -190,14 +190,14 @@ class StableDiffusionXLControlNetPipelineFastTests(
def test_attention_slicing_forward_pass(self):
return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
def test_ip_adapter_single(self, from_ssd1b=False, expected_pipe_slice=None):
def test_ip_adapter(self, from_ssd1b=False, expected_pipe_slice=None):
if not from_ssd1b:
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array(
[0.7335, 0.5866, 0.5623, 0.6242, 0.5751, 0.5999, 0.4091, 0.4590, 0.5054]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
@@ -970,12 +970,12 @@ class StableDiffusionSSD1BControlNetPipelineFastTests(StableDiffusionXLControlNe
# make sure that it's equal
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.7212, 0.5890, 0.5491, 0.6425, 0.5970, 0.6091, 0.4418, 0.4556, 0.5032])
return super().test_ip_adapter_single(from_ssd1b=True, expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(from_ssd1b=True, expected_pipe_slice=expected_pipe_slice)
def test_controlnet_sdxl_lcm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -175,12 +175,12 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
return inputs
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.6276, 0.5271, 0.5205, 0.5393, 0.5774, 0.5872, 0.5456, 0.5415, 0.5354])
# TODO: update after slices.p
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_stable_diffusion_xl_controlnet_img2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -25,6 +25,9 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
@@ -0,0 +1,149 @@
import random
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxImg2ImgPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxImg2ImgPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_sequence_length": 48,
"strength": 0.8,
"output_type": "np",
}
return inputs
def test_flux_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
def test_flux_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
prompt,
prompt_2=None,
device=torch_device,
max_sequence_length=inputs["max_sequence_length"],
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
@@ -0,0 +1,151 @@
import random
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxInpaintPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxInpaintPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
def get_dummy_components(self):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=8,
num_layers=1,
num_single_layers=1,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=2,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
mask_image = torch.ones((1, 1, 32, 32)).to(device)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": image,
"mask_image": mask_image,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 32,
"width": 32,
"max_sequence_length": 48,
"strength": 0.8,
"output_type": "np",
}
return inputs
def test_flux_inpaint_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
def test_flux_inpaint_prompt_embeds(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_with_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
prompt = inputs.pop("prompt")
(prompt_embeds, pooled_prompt_embeds, text_ids) = pipe.encode_prompt(
prompt,
prompt_2=None,
device=torch_device,
max_sequence_length=inputs["max_sequence_length"],
)
output_with_embeds = pipe(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
**inputs,
).images[0]
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
assert max_diff < 1e-4
@@ -550,7 +550,7 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4
def test_ip_adapter_single_mask(self):
def test_ip_adapter_mask(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
@@ -108,11 +108,11 @@ class LatentConsistencyModelPipelineFastTests(
}
return inputs
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.1403, 0.5072, 0.5316, 0.1202, 0.3865, 0.4211, 0.5363, 0.3557, 0.3645])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_lcm_onestep(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -119,11 +119,11 @@ class LatentConsistencyModelImg2ImgPipelineFastTests(
}
return inputs
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.4003, 0.3718, 0.2863, 0.5500, 0.5587, 0.3772, 0.4617, 0.4961, 0.4417])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_lcm_onestep(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
+2 -2
View File
@@ -175,7 +175,7 @@ class AnimateDiffPAGPipelineFastTests(
def test_attention_slicing_forward_pass(self):
pass
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
@@ -210,7 +210,7 @@ class AnimateDiffPAGPipelineFastTests(
0.5538,
]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_dict_tuple_outputs_equivalent(self):
expected_slice = None
+1
View File
@@ -37,6 +37,7 @@ class StableDiffusion3PAGPipelineFastTests(unittest.TestCase, PipelineTesterMixi
]
)
batch_params = frozenset(["prompt", "negative_prompt"])
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
+2 -2
View File
@@ -176,7 +176,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
assert isinstance(pipe.unet, UNetMotionModel)
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
@@ -211,7 +211,7 @@ class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFr
0.5538,
]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_dict_tuple_outputs_equivalent(self):
expected_slice = None
@@ -68,6 +68,8 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"callback_steps",
]
)
# There is not xformers version of the StableAudioPipeline custom attention processor
test_xformers_attention = False
def get_dummy_components(self):
torch.manual_seed(0)
@@ -253,11 +253,11 @@ class StableDiffusionImg2ImgPipelineFastTests(
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.4932, 0.5092, 0.5135, 0.5517, 0.5626, 0.6621, 0.6490, 0.5021, 0.5441])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_stable_diffusion_img2img_multiple_init_images(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -385,14 +385,14 @@ class StableDiffusionInpaintPipelineFastTests(
# they should be the same
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
def test_ip_adapter_single(self, from_simple=False, expected_pipe_slice=None):
def test_ip_adapter(self, from_simple=False, expected_pipe_slice=None):
if not from_simple:
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array(
[0.4390, 0.5452, 0.3772, 0.5448, 0.6031, 0.4480, 0.5194, 0.4687, 0.4640]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests):
@@ -481,11 +481,11 @@ class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipeli
}
return inputs
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.6345, 0.5395, 0.5611, 0.5403, 0.5830, 0.5855, 0.5193, 0.5443, 0.5211])
return super().test_ip_adapter_single(from_simple=True, expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(from_simple=True, expected_pipe_slice=expected_pipe_slice)
def test_stable_diffusion_inpaint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -281,9 +281,6 @@ class StableDiffusionDepth2ImgPipelineFastTests(
max_diff = np.abs(output - output_tuple).max()
self.assertLess(max_diff, 1e-4)
def test_progress_bar(self):
super().test_progress_bar()
def test_stable_diffusion_depth2img_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
@@ -330,12 +330,12 @@ class StableDiffusionXLPipelineFastTests(
# make sure that it's equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.5388, 0.5452, 0.4694, 0.4583, 0.5253, 0.4832, 0.5288, 0.5035, 0.4766])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_attention_slicing_forward_pass(self):
super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
@@ -290,7 +290,7 @@ class StableDiffusionXLAdapterPipelineFastTests(
}
return inputs
def test_ip_adapter_single(self, from_multi=False, expected_pipe_slice=None):
def test_ip_adapter(self, from_multi=False, expected_pipe_slice=None):
if not from_multi:
expected_pipe_slice = None
if torch_device == "cpu":
@@ -298,7 +298,7 @@ class StableDiffusionXLAdapterPipelineFastTests(
[0.5752, 0.6155, 0.4826, 0.5111, 0.5741, 0.4678, 0.5199, 0.5231, 0.4794]
)
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -448,12 +448,12 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
expected_slice = np.array([0.5617, 0.6081, 0.4807, 0.5071, 0.5665, 0.4614, 0.5165, 0.5164, 0.4786])
assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-3
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.5617, 0.6081, 0.4807, 0.5071, 0.5665, 0.4614, 0.5165, 0.5164, 0.4786])
return super().test_ip_adapter_single(from_multi=True, expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(from_multi=True, expected_pipe_slice=expected_pipe_slice)
def test_inference_batch_consistent(
self, batch_sizes=[2, 4, 13], additional_params_copy_to_batched_inputs=["num_inference_steps"]
@@ -310,12 +310,12 @@ class StableDiffusionXLImg2ImgPipelineFastTests(
# make sure that it's equal
assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.5133, 0.4626, 0.4970, 0.6273, 0.5160, 0.6891, 0.6639, 0.5892, 0.5709])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_stable_diffusion_xl_img2img_tiny_autoencoder(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -223,12 +223,12 @@ class StableDiffusionXLInpaintPipelineFastTests(
}
return inputs
def test_ip_adapter_single(self):
def test_ip_adapter(self):
expected_pipe_slice = None
if torch_device == "cpu":
expected_pipe_slice = np.array([0.8274, 0.5538, 0.6141, 0.5843, 0.6865, 0.7082, 0.5861, 0.6123, 0.5344])
return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
def test_components_function(self):
init_components = self.get_dummy_components()
+267
View File
@@ -1,6 +1,25 @@
import contextlib
import io
import re
import unittest
import torch
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
AnimateDiffPipeline,
AnimateDiffVideoToVideoPipeline,
AutoencoderKL,
DDIMScheduler,
MotionAdapter,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.pipelines.pipeline_utils import is_safetensors_compatible
from diffusers.utils.testing_utils import torch_device
class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -177,3 +196,251 @@ class IsSafetensorsCompatibleTests(unittest.TestCase):
"unet/diffusion_pytorch_model.fp16.safetensors",
]
self.assertTrue(is_safetensors_compatible(filenames))
class ProgressBarTests(unittest.TestCase):
def get_dummy_components_image_generation(self):
cross_attention_dim = 8
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=1,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[4, 8],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=16,
layer_norm_eps=1e-05,
num_attention_heads=2,
num_hidden_layers=2,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"safety_checker": None,
"feature_extractor": None,
"image_encoder": None,
}
return components
def get_dummy_components_video_generation(self):
cross_attention_dim = 8
block_out_channels = (8, 8)
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=block_out_channels,
layers_per_block=2,
sample_size=8,
in_channels=4,
out_channels=4,
down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=cross_attention_dim,
norm_num_groups=2,
)
scheduler = DDIMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="linear",
clip_sample=False,
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=block_out_channels,
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
norm_num_groups=2,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=cross_attention_dim,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
torch.manual_seed(0)
motion_adapter = MotionAdapter(
block_out_channels=block_out_channels,
motion_layers_per_block=2,
motion_norm_num_groups=2,
motion_num_attention_heads=4,
)
components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"motion_adapter": motion_adapter,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"feature_extractor": None,
"image_encoder": None,
}
return components
def test_text_to_image(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionPipeline(**components)
pipe.to(torch_device)
inputs = {"prompt": "a cute cat", "num_inference_steps": 2}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_image_to_image(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionImg2ImgPipeline(**components)
pipe.to(torch_device)
image = Image.new("RGB", (32, 32))
inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "strength": 0.5, "image": image}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_inpainting(self):
components = self.get_dummy_components_image_generation()
pipe = StableDiffusionInpaintPipeline(**components)
pipe.to(torch_device)
image = Image.new("RGB", (32, 32))
mask = Image.new("RGB", (32, 32))
inputs = {
"prompt": "a cute cat",
"num_inference_steps": 2,
"strength": 0.5,
"image": image,
"mask_image": mask,
}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_text_to_video(self):
components = self.get_dummy_components_video_generation()
pipe = AnimateDiffPipeline(**components)
pipe.to(torch_device)
inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "num_frames": 2}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_video_to_video(self):
components = self.get_dummy_components_video_generation()
pipe = AnimateDiffVideoToVideoPipeline(**components)
pipe.to(torch_device)
num_frames = 2
video = [Image.new("RGB", (32, 32))] * num_frames
inputs = {"prompt": "a cute cat", "num_inference_steps": 2, "video": video}
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+33
View File
@@ -235,9 +235,32 @@ class AutoPipelineFastTest(unittest.TestCase):
pipe = AutoPipelineForImage2Image.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
def test_from_pretrained_img2img_refiner(self):
repo = "hf-internal-testing/tiny-stable-diffusion-xl-refiner-pipe"
pipe = AutoPipelineForImage2Image.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLImg2ImgPipeline"
controlnet = ControlNetModel.from_pretrained("hf-internal-testing/tiny-controlnet")
pipe_control = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet)
assert pipe_control.__class__.__name__ == "StableDiffusionXLControlNetImg2ImgPipeline"
pipe_pag = AutoPipelineForImage2Image.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGImg2ImgPipeline"
pipe_control_pag = AutoPipelineForImage2Image.from_pretrained(repo, controlnet=controlnet, enable_pag=True)
assert pipe_control_pag.__class__.__name__ == "StableDiffusionXLControlNetPAGImg2ImgPipeline"
def test_from_pipe_pag_img2img(self):
# test from tableDiffusionXLPAGImg2ImgPipeline
pipe = AutoPipelineForImage2Image.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
@@ -265,6 +288,16 @@ class AutoPipelineFastTest(unittest.TestCase):
pipe_pag = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
assert pipe_pag.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
def test_from_pretrained_inpaint_from_inpaint(self):
repo = "hf-internal-testing/tiny-stable-diffusion-xl-inpaint-pipe"
pipe = AutoPipelineForInpainting.from_pretrained(repo)
assert pipe.__class__.__name__ == "StableDiffusionXLInpaintPipeline"
# make sure you can use pag with inpaint-specific pipeline
pipe = AutoPipelineForInpainting.from_pretrained(repo, enable_pag=True)
assert pipe.__class__.__name__ == "StableDiffusionXLPAGInpaintPipeline"
def test_from_pipe_pag_inpaint(self):
# test from tableDiffusionXLPAGInpaintPipeline
pipe = AutoPipelineForInpainting.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
+25 -60
View File
@@ -1,10 +1,7 @@
import contextlib
import gc
import inspect
import io
import json
import os
import re
import tempfile
import unittest
import uuid
@@ -141,52 +138,35 @@ class SDFunctionTesterMixin:
assert np.abs(to_np(output_2) - to_np(output_1)).max() < 5e-1
# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
shapes = [(1, 4, 73, 97), (1, 4, 65, 49)]
with torch.no_grad():
for shape in shapes:
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
# MPS currently doesn't support ComplexFloats, which are required for freeU - see https://github.com/huggingface/diffusers/issues/7569.
# MPS currently doesn't support ComplexFloats, which are required for FreeU - see https://github.com/huggingface/diffusers/issues/7569.
@skip_mps
def test_freeu_enabled(self):
def test_freeu(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# Normal inference
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
inputs["output_type"] = "np"
output = pipe(**inputs)[0]
# FreeU-enabled inference
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
inputs["output_type"] = "np"
output_freeu = pipe(**inputs)[0]
assert not np.allclose(
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
), "Enabling of FreeU should lead to different results."
def test_freeu_disabled(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
inputs["output_type"] = "np"
output = pipe(**inputs)[0]
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
# FreeU-disabled inference
pipe.disable_freeu()
freeu_keys = {"s1", "s2", "b1", "b2"}
for upsample_block in pipe.unet.up_blocks:
for key in freeu_keys:
@@ -195,8 +175,11 @@ class SDFunctionTesterMixin:
inputs = self.get_dummy_inputs(torch_device)
inputs["return_dict"] = False
inputs["output_type"] = "np"
output_no_freeu = pipe(**inputs)[0]
assert not np.allclose(
output[0, -3:, -3:, -1], output_freeu[0, -3:, -3:, -1]
), "Enabling of FreeU should lead to different results."
assert np.allclose(
output, output_no_freeu, atol=1e-2
), f"Disabling of FreeU should lead to results similar to the default pipeline results but Max Abs Error={np.abs(output_no_freeu - output).max()}."
@@ -290,7 +273,15 @@ class IPAdapterTesterMixin:
inputs["return_dict"] = False
return inputs
def test_ip_adapter_single(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
- Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be shaky (with a VVVV low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
@@ -307,6 +298,7 @@ class IPAdapterTesterMixin:
else:
output_without_adapter = expected_pipe_slice
# 1. Single IP-Adapter test cases
adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
@@ -338,16 +330,7 @@ class IPAdapterTesterMixin:
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
)
def test_ip_adapter_multi(self, expected_max_diff: float = 1e-4):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
output_without_adapter = pipe(**inputs)[0]
# 2. Multi IP-Adapter test cases
adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
@@ -357,12 +340,16 @@ class IPAdapterTesterMixin:
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([0.0, 0.0])
output_without_multi_adapter_scale = pipe(**inputs)[0]
if expected_pipe_slice is not None:
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with multi ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([42.0, 42.0])
output_with_multi_adapter_scale = pipe(**inputs)[0]
if expected_pipe_slice is not None:
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_multi_adapter_scale = np.abs(
output_without_multi_adapter_scale - output_without_adapter
@@ -1689,28 +1676,6 @@ class PipelineTesterMixin:
if test_mean_pixel_difference:
assert_mean_pixel_difference(output_with_offload[0], output_without_offload[0])
def test_progress_bar(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
def test_num_images_per_prompt(self):
sig = inspect.signature(self.pipeline_class.__call__)
@@ -173,9 +173,6 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, SDFunctionTesterMixin,
def test_num_images_per_prompt(self):
pass
def test_progress_bar(self):
return super().test_progress_bar()
@slow
@skip_mps
@@ -13,11 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import gc
import inspect
import io
import re
import tempfile
import unittest
@@ -282,28 +279,6 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipe
def test_pipeline_call_signature(self):
pass
def test_progress_bar(self):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(torch_device)
inputs = self.get_dummy_inputs(self.generator_device)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
stderr = stderr.getvalue()
# we can't calculate the number of progress steps beforehand e.g. for strength-dependent img2img,
# so we just match "5" in "#####| 1/5 [00:01<00:00]"
max_steps = re.search("/(.*?) ", stderr).group(1)
self.assertTrue(max_steps is not None and len(max_steps) > 0)
self.assertTrue(
f"{max_steps}/{max_steps}" in stderr, "Progress bar should be enabled and stopped at the max step"
)
pipe.set_progress_bar_config(disable=True)
with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
_ = pipe(**inputs)
self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
@unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA")
def test_save_load_float16(self, expected_max_diff=1e-2):
components = self.get_dummy_components()
@@ -197,9 +197,6 @@ class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_num_images_per_prompt(self):
pass
def test_progress_bar(self):
return super().test_progress_bar()
@nightly
@skip_mps
+13 -12
View File
@@ -5,6 +5,7 @@ import requests
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
@@ -98,8 +99,8 @@ class SDSingleFileTesterMixin:
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
local_ckpt_path, safety_checker=None, local_files_only=True
@@ -138,8 +139,8 @@ class SDSingleFileTesterMixin:
upcast_attention = pipe.unet.config.upcast_attention
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -191,8 +192,8 @@ class SDSingleFileTesterMixin:
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -286,8 +287,8 @@ class SDXLSingleFileTesterMixin:
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
local_ckpt_path, safety_checker=None, local_files_only=True
@@ -327,8 +328,8 @@ class SDXLSingleFileTesterMixin:
upcast_attention = pipe.unet.config.upcast_attention
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -364,8 +365,8 @@ class SDXLSingleFileTesterMixin:
pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
@@ -5,6 +5,7 @@ import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
@@ -29,11 +30,11 @@ enable_full_determinism()
@require_torch_gpu
class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_pruned.safetensors"
original_config = (
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
repo_id = "runwayml/stable-diffusion-v1-5"
repo_id = "Lykon/dreamshaper-8"
def setUp(self):
super().setUp()
@@ -108,8 +109,8 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
@@ -136,8 +137,9 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
@@ -168,8 +170,9 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weights_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weights_name, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
@@ -5,6 +5,7 @@ import unittest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
@@ -28,9 +29,9 @@ enable_full_determinism()
@require_torch_gpu
class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetInpaintPipeline
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
ckpt_path = "https://huggingface.co/Lykon/DreamShaper/blob/main/DreamShaper_8_INPAINTING.inpainting.safetensors"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "runwayml/stable-diffusion-inpainting"
repo_id = "Lykon/dreamshaper-8-inpainting"
def setUp(self):
super().setUp()
@@ -83,7 +84,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
output_sf = pipe_sf(**inputs).images[0]
max_diff = numpy_cosine_similarity_distance(output_sf.flatten(), output.flatten())
assert max_diff < 1e-3
assert max_diff < 2e-3
def test_single_file_components(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny")
@@ -103,8 +104,8 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None, controlnet=controlnet)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
local_ckpt_path, controlnet=controlnet, safety_checker=None, local_files_only=True
@@ -112,6 +113,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
super()._compare_component_configs(pipe, pipe_single_file)
@unittest.skip("runwayml original config repo does not exist")
def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
@@ -121,6 +123,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
super()._compare_component_configs(pipe, pipe_single_file)
@unittest.skip("runwayml original config repo does not exist")
def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
@@ -132,8 +135,8 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
local_original_config = download_original_config(self.original_config, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(
@@ -169,8 +172,8 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
)
with tempfile.TemporaryDirectory() as tmpdir:
ckpt_filename = self.ckpt_path.split("/")[-1]
local_ckpt_path = download_single_file_checkpoint(self.repo_id, ckpt_filename, tmpdir)
repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)
pipe_single_file = self.pipeline_class.from_single_file(

Some files were not shown because too many files have changed in this diff Show More