Compare commits
123 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c73c00610e | |||
| 233dffdc3f | |||
| be2070991f | |||
| bf9a641f1a | |||
| a756694bf0 | |||
| d41388145e | |||
| a6288a5571 | |||
| 7d4db57037 | |||
| 902008608a | |||
| c8ee4af228 | |||
| b64ca6c11c | |||
| e12d610faa | |||
| bf6eaa8aec | |||
| 17128c42a4 | |||
| dbc1d505f0 | |||
| 151b74cd77 | |||
| 41ba8c0bf6 | |||
| 3191248472 | |||
| 648d968cfc | |||
| b756ec6e80 | |||
| d8825e7697 | |||
| 074798b299 | |||
| 3ee966950b | |||
| 9764f229d4 | |||
| 1826a1e7d3 | |||
| 0ed09a17bb | |||
| 2f7a417d1f | |||
| 4450d26b63 | |||
| f781b8c30c | |||
| 9c0e20de61 | |||
| f35a38725b | |||
| f66bd3261c | |||
| c4c99c3907 | |||
| 862a7d5038 | |||
| 8304adce2a | |||
| b389f339ec | |||
| e222246b4e | |||
| 83709d5a06 | |||
| 8eb73c872a | |||
| 88b015dc9f | |||
| 63cdf9c0ba | |||
| 0ac52d6f09 | |||
| ba6fd6eb30 | |||
| 9408aa2dfc | |||
| ec1c7a793f | |||
| 9c68c945e9 | |||
| 2739241ad1 | |||
| 1524781b88 | |||
| 128b96f369 | |||
| e24941b2a7 | |||
| f9d5a9324d | |||
| ac86393487 | |||
| 0d96a894a7 | |||
| 6fb94d51cb | |||
| 7667cfcb41 | |||
| 9f00c617a0 | |||
| aafed3f8dd | |||
| 5ed761a6f2 | |||
| 2f023d7b84 | |||
| e9a3911b67 | |||
| 7186bb45f0 | |||
| 438bd60549 | |||
| 87e8157437 | |||
| 3f421fe09f | |||
| a7d50524dd | |||
| 672bd49573 | |||
| ea893a9ae7 | |||
| 5fb3a98517 | |||
| aace1f412b | |||
| 8957324363 | |||
| e68092a471 | |||
| 3bf5400a64 | |||
| 02cbe972c3 | |||
| 5a196e3d46 | |||
| 22c4f079b1 | |||
| 96a9097445 | |||
| a5f35ee473 | |||
| 63243406ba | |||
| 6bd30ba748 | |||
| cef0e3677e | |||
| ec9bfa9e14 | |||
| bdbaea8f64 | |||
| e8b65bffa2 | |||
| f2d348d904 | |||
| c002724dd5 | |||
| 96c376a5ff | |||
| 8170dc368d | |||
| 25f3e91c81 | |||
| a6a18cff5e | |||
| 7db9463e52 | |||
| 26e80e0143 | |||
| 914a585be8 | |||
| ad40e26515 | |||
| d041dd5040 | |||
| 0967593400 | |||
| 43534a8d1f | |||
| 65b98b5da4 | |||
| 49a9143479 | |||
| 4c4b323c1f | |||
| 22d3a82651 | |||
| c9e4fab42c | |||
| 0e50401e34 | |||
| 6131a93b96 | |||
| 3cb7b8628c | |||
| fa3a9100be | |||
| 188bca3084 | |||
| cd892041e2 | |||
| 6394d905da | |||
| 18f9b99088 | |||
| bf64b32652 | |||
| 3335e2262d | |||
| 65ab1052b8 | |||
| 40fc389c44 | |||
| 98d0cd5778 | |||
| 0d11ab26c4 | |||
| 243d9a4986 | |||
| 96220390a2 | |||
| 73dac0c49e | |||
| 04bba38725 | |||
| a2d424eb2e | |||
| 25ddc7945b | |||
| e8da75dff5 | |||
| 8a450c3da0 |
@@ -238,12 +238,13 @@ jobs:
|
||||
|
||||
run_flax_tpu_tests:
|
||||
name: Nightly Flax TPU Tests
|
||||
runs-on: docker-tpu
|
||||
runs-on:
|
||||
group: gcp-ct5lp-hightpu-8t
|
||||
if: github.event_name == 'schedule'
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged
|
||||
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -356,6 +357,8 @@ jobs:
|
||||
config:
|
||||
- backend: "bitsandbytes"
|
||||
test_location: "bnb"
|
||||
- backend: "gguf"
|
||||
test_location: "gguf"
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
@@ -519,4 +522,4 @@ jobs:
|
||||
# if: always()
|
||||
# run: |
|
||||
# pip install slack_sdk tabulate
|
||||
# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
name: Fast tests for PRs - PEFT backend
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "src/diffusers/**.py"
|
||||
- "tests/**.py"
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 4
|
||||
MKL_NUM_THREADS: 4
|
||||
PYTEST_TIMEOUT: 60
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check quality
|
||||
run: make quality
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Quality check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make style && make quality'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
check_repository_consistency:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install .[quality]
|
||||
- name: Check repo consistency
|
||||
run: |
|
||||
python utils/check_copies.py
|
||||
python utils/check_dummies.py
|
||||
make deps_table_check_updated
|
||||
- name: Check if failure
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
lib-versions: ["main", "latest"]
|
||||
|
||||
|
||||
name: LoRA - ${{ matrix.lib-versions }}
|
||||
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
# TODO (sayakpaul, DN6): revisit `--no-deps`
|
||||
if [ "${{ matrix.lib-versions }}" == "main" ]; then
|
||||
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
else
|
||||
python -m uv pip install -U peft --no-deps
|
||||
python -m uv pip install -U transformers accelerate --no-deps
|
||||
fi
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_${{ matrix.lib-versions }} \
|
||||
tests/lora/
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_models_lora_${{ matrix.lib-versions }} \
|
||||
tests/models/ -k "lora"
|
||||
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_${{ matrix.lib-versions }}_failures_short.txt
|
||||
cat reports/tests_models_lora_${{ matrix.lib-versions }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.lib-versions }}_test_reports
|
||||
path: reports
|
||||
@@ -234,3 +234,67 @@ jobs:
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
run_lora_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
name: LoRA tests with PEFT main
|
||||
|
||||
runs-on:
|
||||
group: aws-general-8-plus
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
# TODO (sayakpaul, DN6): revisit `--no-deps`
|
||||
python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch LoRA tests with PEFT
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_peft_main \
|
||||
tests/lora/
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v \
|
||||
--make-reports=tests_models_lora_peft_main \
|
||||
tests/models/ -k "lora"
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_lora_failures_short.txt
|
||||
cat reports/tests_models_lora_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_main_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
@@ -161,10 +161,11 @@ jobs:
|
||||
|
||||
flax_tpu_tests:
|
||||
name: Flax TPU Tests
|
||||
runs-on: docker-tpu
|
||||
runs-on:
|
||||
group: gcp-ct5lp-hightpu-8t
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
|
||||
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pip install --upgrade pip uv
|
||||
${CONDA_RUN} python -m uv pip install -e [quality,test]
|
||||
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
|
||||
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
|
||||
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
${CONDA_RUN} python -m uv pip install transformers --upgrade
|
||||
|
||||
@@ -112,8 +112,8 @@ Check out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to l
|
||||
| **Documentation** | **What can I learn?** |
|
||||
|---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| [Tutorial](https://huggingface.co/docs/diffusers/tutorials/tutorial_overview) | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model. |
|
||||
| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading_overview) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. |
|
||||
| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/pipeline_overview) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. |
|
||||
| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. |
|
||||
| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/overview_techniques) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. |
|
||||
| [Optimization](https://huggingface.co/docs/diffusers/optimization/fp16) | Guides for how to optimize your diffusion model to run faster and consume less memory. |
|
||||
| [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques. |
|
||||
## Contribution
|
||||
|
||||
@@ -157,6 +157,10 @@
|
||||
title: Getting Started
|
||||
- local: quantization/bitsandbytes
|
||||
title: bitsandbytes
|
||||
- local: quantization/gguf
|
||||
title: gguf
|
||||
- local: quantization/torchao
|
||||
title: torchao
|
||||
title: Quantization Methods
|
||||
- sections:
|
||||
- local: optimization/fp16
|
||||
@@ -234,6 +238,8 @@
|
||||
title: Textual Inversion
|
||||
- local: api/loaders/unet
|
||||
title: UNet
|
||||
- local: api/loaders/transformer_sd3
|
||||
title: SD3Transformer2D
|
||||
- local: api/loaders/peft
|
||||
title: PEFT
|
||||
title: Loaders
|
||||
@@ -252,6 +258,8 @@
|
||||
title: SD3ControlNetModel
|
||||
- local: api/models/controlnet_sparsectrl
|
||||
title: SparseControlNetModel
|
||||
- local: api/models/controlnet_union
|
||||
title: ControlNetUnionModel
|
||||
title: ControlNets
|
||||
- sections:
|
||||
- local: api/models/allegro_transformer3d
|
||||
@@ -268,10 +276,14 @@
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/hunyuan_video_transformer_3d
|
||||
title: HunyuanVideoTransformer3DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/ltx_video_transformer3d
|
||||
title: LTXVideoTransformer3DModel
|
||||
- local: api/models/mochi_transformer3d
|
||||
title: MochiTransformer3DModel
|
||||
- local: api/models/pixart_transformer2d
|
||||
@@ -280,6 +292,8 @@
|
||||
title: PriorTransformer
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/sana_transformer2d
|
||||
title: SanaTransformer2DModel
|
||||
- local: api/models/stable_audio_transformer
|
||||
title: StableAudioDiTModel
|
||||
- local: api/models/transformer2d
|
||||
@@ -310,10 +324,16 @@
|
||||
title: AutoencoderKLAllegro
|
||||
- local: api/models/autoencoderkl_cogvideox
|
||||
title: AutoencoderKLCogVideoX
|
||||
- local: api/models/autoencoder_kl_hunyuan_video
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
title: AutoencoderKLLTXVideo
|
||||
- local: api/models/autoencoderkl_mochi
|
||||
title: AutoencoderKLMochi
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
title: AutoencoderDC
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/autoencoder_oobleck
|
||||
@@ -366,6 +386,8 @@
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/ddim
|
||||
@@ -380,8 +402,12 @@
|
||||
title: DiT
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/pix2pix
|
||||
@@ -402,6 +428,8 @@
|
||||
title: Latte
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTX
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
@@ -422,6 +450,8 @@
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pixart_sigma
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
|
||||
@@ -15,40 +15,135 @@ specific language governing permissions and limitations under the License.
|
||||
An attention processor is a class for applying different types of attention mechanisms.
|
||||
|
||||
## AttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.AttnProcessor
|
||||
|
||||
## AttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.AttnProcessor2_0
|
||||
|
||||
## AttnAddedKVProcessor
|
||||
[[autodoc]] models.attention_processor.AttnAddedKVProcessor
|
||||
|
||||
## AttnAddedKVProcessor2_0
|
||||
[[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0
|
||||
|
||||
## CrossFrameAttnProcessor
|
||||
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
|
||||
[[autodoc]] models.attention_processor.AttnProcessorNPU
|
||||
|
||||
## CustomDiffusionAttnProcessor
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
|
||||
|
||||
## CustomDiffusionAttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
|
||||
|
||||
## CustomDiffusionXFormersAttnProcessor
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
|
||||
|
||||
## FusedAttnProcessor2_0
|
||||
[[autodoc]] models.attention_processor.FusedAttnProcessor2_0
|
||||
|
||||
## Allegro
|
||||
|
||||
[[autodoc]] models.attention_processor.AllegroAttnProcessor2_0
|
||||
|
||||
## AuraFlow
|
||||
|
||||
[[autodoc]] models.attention_processor.AuraFlowAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedAuraFlowAttnProcessor2_0
|
||||
|
||||
## CogVideoX
|
||||
|
||||
[[autodoc]] models.attention_processor.CogVideoXAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedCogVideoXAttnProcessor2_0
|
||||
|
||||
## CrossFrameAttnProcessor
|
||||
|
||||
[[autodoc]] pipelines.text_to_video_synthesis.pipeline_text_to_video_zero.CrossFrameAttnProcessor
|
||||
|
||||
## Custom Diffusion
|
||||
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor
|
||||
|
||||
## Flux
|
||||
|
||||
[[autodoc]] models.attention_processor.FluxAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedFluxAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FluxSingleAttnProcessor2_0
|
||||
|
||||
## Hunyuan
|
||||
|
||||
[[autodoc]] models.attention_processor.HunyuanAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedHunyuanAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGHunyuanAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGCFGHunyuanAttnProcessor2_0
|
||||
|
||||
## IdentitySelfAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGIdentitySelfAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGCFGIdentitySelfAttnProcessor2_0
|
||||
|
||||
## IP-Adapter
|
||||
|
||||
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0
|
||||
|
||||
## JointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.JointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGJointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGCFGJointAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.FusedJointAttnProcessor2_0
|
||||
|
||||
## LoRA
|
||||
|
||||
[[autodoc]] models.attention_processor.LoRAAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
|
||||
|
||||
## Lumina-T2X
|
||||
|
||||
[[autodoc]] models.attention_processor.LuminaAttnProcessor2_0
|
||||
|
||||
## Mochi
|
||||
|
||||
[[autodoc]] models.attention_processor.MochiAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.MochiVaeAttnProcessor2_0
|
||||
|
||||
## Sana
|
||||
|
||||
[[autodoc]] models.attention_processor.SanaLinearAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.SanaMultiscaleAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGCFGSanaLinearAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.PAGIdentitySanaLinearAttnProcessor2_0
|
||||
|
||||
## Stable Audio
|
||||
|
||||
[[autodoc]] models.attention_processor.StableAudioAttnProcessor2_0
|
||||
|
||||
## SlicedAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.SlicedAttnProcessor
|
||||
|
||||
## SlicedAttnAddedKVProcessor
|
||||
[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor
|
||||
|
||||
## XFormersAttnProcessor
|
||||
|
||||
[[autodoc]] models.attention_processor.XFormersAttnProcessor
|
||||
|
||||
## AttnProcessorNPU
|
||||
[[autodoc]] models.attention_processor.AttnProcessorNPU
|
||||
[[autodoc]] models.attention_processor.XFormersAttnAddedKVProcessor
|
||||
|
||||
## XLAFlashAttnProcessor2_0
|
||||
|
||||
[[autodoc]] models.attention_processor.XLAFlashAttnProcessor2_0
|
||||
|
||||
@@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.IPAdapterMixin
|
||||
|
||||
## SD3IPAdapterMixin
|
||||
|
||||
[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
|
||||
- all
|
||||
- is_ip_adapter_active
|
||||
|
||||
## IPAdapterMaskProcessor
|
||||
|
||||
[[autodoc]] image_processor.IPAdapterMaskProcessor
|
||||
@@ -17,6 +17,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`StableDiffusionLoraLoaderMixin`] provides functions for loading and unloading, fusing and unfusing, enabling and disabling, and more functions for managing LoRA weights. This class can be used with any model.
|
||||
- [`StableDiffusionXLLoraLoaderMixin`] is a [Stable Diffusion (SDXL)](../../api/pipelines/stable_diffusion/stable_diffusion_xl) version of the [`StableDiffusionLoraLoaderMixin`] class for loading and saving LoRA weights. It can only be used with the SDXL model.
|
||||
- [`SD3LoraLoaderMixin`] provides similar functions for [Stable Diffusion 3](https://huggingface.co/blog/sd3).
|
||||
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
|
||||
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
|
||||
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
|
||||
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
@@ -38,6 +41,18 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
|
||||
|
||||
## FluxLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
|
||||
|
||||
## CogVideoXLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
|
||||
|
||||
## Mochi1LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
|
||||
|
||||
## AmusedLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# SD3Transformer2D
|
||||
|
||||
This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
|
||||
|
||||
The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
|
||||
|
||||
<Tip>
|
||||
|
||||
To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide.
|
||||
|
||||
</Tip>
|
||||
|
||||
## SD3Transformer2DLoadersMixin
|
||||
|
||||
[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
|
||||
- all
|
||||
- _load_ip_adapter_weights
|
||||
@@ -0,0 +1,72 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderDC
|
||||
|
||||
The 2D Autoencoder model used in [SANA](https://huggingface.co/papers/2410.10629) and introduced in [DCAE](https://huggingface.co/papers/2410.10733) by authors Junyu Chen\*, Han Cai\*, Junsong Chen, Enze Xie, Shang Yang, Haotian Tang, Muyang Li, Yao Lu, Song Han from MIT HAN Lab.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present Deep Compression Autoencoder (DC-AE), a new family of autoencoder models for accelerating high-resolution diffusion models. Existing autoencoder models have demonstrated impressive results at a moderate spatial compression ratio (e.g., 8x), but fail to maintain satisfactory reconstruction accuracy for high spatial compression ratios (e.g., 64x). We address this challenge by introducing two key techniques: (1) Residual Autoencoding, where we design our models to learn residuals based on the space-to-channel transformed features to alleviate the optimization difficulty of high spatial-compression autoencoders; (2) Decoupled High-Resolution Adaptation, an efficient decoupled three-phases training strategy for mitigating the generalization penalty of high spatial-compression autoencoders. With these designs, we improve the autoencoder's spatial compression ratio up to 128 while maintaining the reconstruction quality. Applying our DC-AE to latent diffusion models, we achieve significant speedup without accuracy drop. For example, on ImageNet 512x512, our DC-AE provides 19.1x inference speedup and 17.9x training speedup on H100 GPU for UViT-H while achieving a better FID, compared with the widely used SD-VAE-f8 autoencoder. Our code is available at [this https URL](https://github.com/mit-han-lab/efficientvit).*
|
||||
|
||||
The following DCAE models are released and supported in Diffusers.
|
||||
|
||||
| Diffusers format | Original format |
|
||||
|:----------------:|:---------------:|
|
||||
| [`mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-sana-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0)
|
||||
| [`mit-han-lab/dc-ae-f32c32-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-in-1.0)
|
||||
| [`mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f32c32-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f32c32-mix-1.0)
|
||||
| [`mit-han-lab/dc-ae-f64c128-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-in-1.0)
|
||||
| [`mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f64c128-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f64c128-mix-1.0)
|
||||
| [`mit-han-lab/dc-ae-f128c512-in-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-in-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0)
|
||||
| [`mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers) | [`mit-han-lab/dc-ae-f128c512-mix-1.0`](https://huggingface.co/mit-han-lab/dc-ae-f128c512-mix-1.0)
|
||||
|
||||
This model was contributed by [lawrence-cj](https://github.com/lawrence-cj).
|
||||
|
||||
Load a model in Diffusers format with [`~ModelMixin.from_pretrained`].
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderDC
|
||||
|
||||
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## Load a model in Diffusers via `from_single_file`
|
||||
|
||||
```python
|
||||
from difusers import AutoencoderDC
|
||||
|
||||
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
|
||||
model = AutoencoderDC.from_single_file(ckpt_path)
|
||||
|
||||
```
|
||||
|
||||
The `AutoencoderDC` model has `in` and `mix` single file checkpoint variants that have matching checkpoint keys, but use different scaling factors. It is not possible for Diffusers to automatically infer the correct config file to use with the model based on just the checkpoint and will default to configuring the model using the `mix` variant config file. To override the automatically determined config, please use the `config` argument when using single file loading with `in` variant checkpoints.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderDC
|
||||
|
||||
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f128c512-in-1.0/blob/main/model.safetensors"
|
||||
model = AutoencoderDC.from_single_file(ckpt_path, config="mit-han-lab/dc-ae-f128c512-in-1.0-diffusers")
|
||||
```
|
||||
|
||||
|
||||
## AutoencoderDC
|
||||
|
||||
[[autodoc]] AutoencoderDC
|
||||
- encode
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLHunyuanVideo
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [HunyuanVideo](https://github.com/Tencent/HunyuanVideo/), which was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLHunyuanVideo
|
||||
|
||||
vae = AutoencoderKLHunyuanVideo.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## AutoencoderKLHunyuanVideo
|
||||
|
||||
[[autodoc]] AutoencoderKLHunyuanVideo
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -0,0 +1,37 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLLTXVideo
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTXVideo
|
||||
|
||||
vae = AutoencoderKLLTXVideo.from_pretrained("TODO/TODO", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTXVideo
|
||||
|
||||
[[autodoc]] AutoencoderKLLTXVideo
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
|
||||
## AutoencoderKLOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
@@ -0,0 +1,35 @@
|
||||
<!--Copyright 2024 The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNetUnionModel
|
||||
|
||||
ControlNetUnionModel is an implementation of ControlNet for Stable Diffusion XL.
|
||||
|
||||
The ControlNet model was introduced in [ControlNetPlus](https://github.com/xinsir6/ControlNetPlus) by xinsir6. It supports multiple conditioning inputs without increasing computation.
|
||||
|
||||
*We design a new architecture that can support 10+ control types in condition text-to-image generation and can generate high resolution images visually comparable with midjourney. The network is based on the original ControlNet architecture, we propose two new modules to: 1 Extend the original ControlNet to support different image conditions using the same network parameter. 2 Support multiple conditions input without increasing computation offload, which is especially important for designers who want to edit image in detail, different conditions use the same condition encoder, without adding extra computations or parameters.*
|
||||
|
||||
## Loading
|
||||
|
||||
By default the [`ControlNetUnionModel`] should be loaded with [`~ModelMixin.from_pretrained`].
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel
|
||||
|
||||
controlnet = ControlNetUnionModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0")
|
||||
pipe = StableDiffusionXLControlNetUnionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet)
|
||||
```
|
||||
|
||||
## ControlNetUnionModel
|
||||
|
||||
[[autodoc]] ControlNetUnionModel
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# HunyuanVideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [HunyuanVideo: A Systematic Framework For Large Video Generative Models](https://huggingface.co/papers/2412.03603) by Tencent.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanVideoTransformer3DModel
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained("tencent/HunyuanVideo", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HunyuanVideoTransformer3DModel
|
||||
|
||||
[[autodoc]] HunyuanVideoTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# LTXVideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-Video) was introduced by Lightricks.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import LTXVideoTransformer3DModel
|
||||
|
||||
transformer = LTXVideoTransformer3DModel.from_pretrained("TODO/TODO", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## LTXVideoTransformer3DModel
|
||||
|
||||
[[autodoc]] LTXVideoTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,34 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# SanaTransformer2DModel
|
||||
|
||||
A Diffusion Transformer model for 2D data from [SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) was introduced from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import SanaTransformer2DModel
|
||||
|
||||
transformer = SanaTransformer2DModel.from_pretrained("Efficient-Large-Model/Sana_1600M_1024px_diffusers", subfolder="transformer", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## SanaTransformer2DModel
|
||||
|
||||
[[autodoc]] SanaTransformer2DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -0,0 +1,89 @@
|
||||
<!--Copyright 2024 The HuggingFace Team, The Black Forest Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# FluxControlInpaint
|
||||
|
||||
FluxControlInpaintPipeline is an implementation of Inpainting for Flux.1 Depth/Canny models. It is a pipeline that allows you to inpaint images using the Flux.1 Depth/Canny models. The pipeline takes an image and a mask as input and returns the inpainted image.
|
||||
|
||||
FLUX.1 Depth and Canny [dev] is a 12 billion parameter rectified flow transformer capable of generating an image based on a text description while following the structure of a given input image. **This is not a ControlNet model**.
|
||||
|
||||
| Control type | Developer | Link |
|
||||
| -------- | ---------- | ---- |
|
||||
| Depth | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
|
||||
| Canny | [Black Forest Labs](https://huggingface.co/black-forest-labs) | [Link](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
Flux can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details. Additionally, Flux can benefit from quantization for memory efficiency with a trade-off in inference latency. Refer to [this blog post](https://huggingface.co/blog/quanto-diffusers) to learn more. For an exhaustive list of resources, check out [this gist](https://gist.github.com/sayakpaul/b664605caf0aa3bf8585ab109dd5ac9c).
|
||||
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxControlInpaintPipeline
|
||||
from diffusers.models.transformers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
from diffusers.utils import load_image, make_image_grid
|
||||
from image_gen_aux import DepthPreprocessor # https://github.com/huggingface/image_gen_aux
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
pipe = FluxControlInpaintPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Depth-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
# use following lines if you have GPU constraints
|
||||
# ---------------------------------------------------------------
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
text_encoder_2 = T5EncoderModel.from_pretrained(
|
||||
"sayakpaul/FLUX.1-Depth-dev-nf4", subfolder="text_encoder_2", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.transformer = transformer
|
||||
pipe.text_encoder_2 = text_encoder_2
|
||||
pipe.enable_model_cpu_offload()
|
||||
# ---------------------------------------------------------------
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "a blue robot singing opera with human-like expressions"
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
head_mask = np.zeros_like(image)
|
||||
head_mask[65:580,300:642] = 255
|
||||
mask_image = Image.fromarray(head_mask)
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(image)[0].convert("RGB")
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
control_image=control_image,
|
||||
mask_image=mask_image,
|
||||
num_inference_steps=30,
|
||||
strength=0.9,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
make_image_grid([image, control_image, mask_image, output.resize(image.size)], rows=1, cols=4).save("output.png")
|
||||
```
|
||||
|
||||
## FluxControlInpaintPipeline
|
||||
[[autodoc]] FluxControlInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## FluxPipelineOutput
|
||||
[[autodoc]] pipelines.flux.pipeline_output.FluxPipelineOutput
|
||||
@@ -0,0 +1,35 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNetUnion
|
||||
|
||||
ControlNetUnionModel is an implementation of ControlNet for Stable Diffusion XL.
|
||||
|
||||
The ControlNet model was introduced in [ControlNetPlus](https://github.com/xinsir6/ControlNetPlus) by xinsir6. It supports multiple conditioning inputs without increasing computation.
|
||||
|
||||
*We design a new architecture that can support 10+ control types in condition text-to-image generation and can generate high resolution images visually comparable with midjourney. The network is based on the original ControlNet architecture, we propose two new modules to: 1 Extend the original ControlNet to support different image conditions using the same network parameter. 2 Support multiple conditions input without increasing computation offload, which is especially important for designers who want to edit image in detail, different conditions use the same condition encoder, without adding extra computations or parameters.*
|
||||
|
||||
|
||||
## StableDiffusionXLControlNetUnionPipeline
|
||||
[[autodoc]] StableDiffusionXLControlNetUnionPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLControlNetUnionImg2ImgPipeline
|
||||
[[autodoc]] StableDiffusionXLControlNetUnionImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLControlNetUnionInpaintPipeline
|
||||
[[autodoc]] StableDiffusionXLControlNetUnionInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -143,6 +143,35 @@ image = pipe(
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
Canny Control is also possible with a LoRA variant of this condition. The usage is as follows:
|
||||
|
||||
```python
|
||||
# !pip install -U controlnet-aux
|
||||
import torch
|
||||
from controlnet_aux import CannyDetector
|
||||
from diffusers import FluxControlPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")
|
||||
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
|
||||
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
processor = CannyDetector()
|
||||
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=30.0,
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
### Depth Control
|
||||
|
||||
**Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible.
|
||||
@@ -174,6 +203,36 @@ image = pipe(
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
Depth Control is also possible with a LoRA variant of this condition. The usage is as follows:
|
||||
|
||||
```python
|
||||
# !pip install git+https://github.com/huggingface/image_gen_aux
|
||||
import torch
|
||||
from diffusers import FluxControlPipeline, FluxTransformer2DModel
|
||||
from diffusers.utils import load_image
|
||||
from image_gen_aux import DepthPreprocessor
|
||||
|
||||
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora")
|
||||
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
|
||||
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(control_image)[0].convert("RGB")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
### Redux
|
||||
|
||||
* Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell, for image-to-image generation.
|
||||
@@ -209,6 +268,43 @@ images = pipe(
|
||||
images[0].save("flux-redux.png")
|
||||
```
|
||||
|
||||
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
|
||||
|
||||
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
|
||||
|
||||
```py
|
||||
from diffusers import FluxControlPipeline
|
||||
from image_gen_aux import DepthPreprocessor
|
||||
from diffusers.utils import load_image
|
||||
from huggingface_hub import hf_hub_download
|
||||
import torch
|
||||
|
||||
control_pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
control_pipe.load_lora_weights("black-forest-labs/FLUX.1-Depth-dev-lora", adapter_name="depth")
|
||||
control_pipe.load_lora_weights(
|
||||
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
|
||||
)
|
||||
control_pipe.set_adapters(["depth", "hyper-sd"], adapter_weights=[0.85, 0.125])
|
||||
control_pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
|
||||
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(control_image)[0].convert("RGB")
|
||||
|
||||
image = control_pipe(
|
||||
prompt=prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
## Running FP16 inference
|
||||
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# HunyuanVideo
|
||||
|
||||
[HunyuanVideo](https://www.arxiv.org/abs/2412.03603) by Tencent.
|
||||
|
||||
*Recent advancements in video generation have significantly impacted daily life for both individuals and industries. However, the leading video generation models remain closed-source, resulting in a notable performance gap between industry capabilities and those available to the public. In this report, we introduce HunyuanVideo, an innovative open-source video foundation model that demonstrates performance in video generation comparable to, or even surpassing, that of leading closed-source models. HunyuanVideo encompasses a comprehensive framework that integrates several key elements, including data curation, advanced architectural design, progressive model scaling and training, and an efficient infrastructure tailored for large-scale model training and inference. As a result, we successfully trained a video generative model with over 13 billion parameters, making it the largest among all open-source models. We conducted extensive experiments and implemented a series of targeted designs to ensure high visual quality, motion dynamics, text-video alignment, and advanced filming techniques. According to evaluations by professionals, HunyuanVideo outperforms previous state-of-the-art models, including Runway Gen-3, Luma 1.6, and three top-performing Chinese video generative models. By releasing the code for the foundation model and its applications, we aim to bridge the gap between closed-source and open-source communities. This initiative will empower individuals within the community to experiment with their ideas, fostering a more dynamic and vibrant video generation ecosystem. The code is publicly available at [this https URL](https://github.com/Tencent/HunyuanVideo).*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
Recommendations for inference:
|
||||
- Both text encoders should be in `torch.float16`.
|
||||
- Transformer should be in `torch.bfloat16`.
|
||||
- VAE should be in `torch.float16`.
|
||||
- `num_frames` should be of the form `4 * k + 1`, for example `49` or `129`.
|
||||
- For smaller resolution images, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution images, try higher values (between `7.0` and `12.0`). The default value is `7.0` for HunyuanVideo.
|
||||
- For more information about supported resolutions and other details, please refer to the original repository [here](https://github.com/Tencent/HunyuanVideo/).
|
||||
|
||||
## HunyuanVideoPipeline
|
||||
|
||||
[[autodoc]] HunyuanVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HunyuanVideoPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.hunyuan_video.pipeline_output.HunyuanVideoPipelineOutput
|
||||
@@ -0,0 +1,118 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# LTX
|
||||
|
||||
[LTX Video](https://huggingface.co/Lightricks/LTX-Video) is the first DiT-based video generation model capable of generating high-quality videos in real-time. It produces 24 FPS videos at a 768x512 resolution faster than they can be watched. Trained on a large-scale dataset of diverse videos, the model generates high-resolution videos with realistic and varied content. We provide a model for both text-to-video as well as image + text-to-video usecases.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Loading Single Files
|
||||
|
||||
Loading the original LTX Video checkpoints is also possible with [`~ModelMixin.from_single_file`].
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLLTXVideo, LTXImageToVideoPipeline, LTXVideoTransformer3DModel
|
||||
|
||||
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
single_file_url, torch_dtype=torch.bfloat16
|
||||
)
|
||||
vae = AutoencoderKLLTXVideo.from_single_file(single_file_url, torch_dtype=torch.bfloat16)
|
||||
pipe = LTXImageToVideoPipeline.from_pretrained(
|
||||
"Lightricks/LTX-Video", transformer=transformer, vae=vae, torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
# ... inference code ...
|
||||
```
|
||||
|
||||
Alternatively, the pipeline can be used to load the weights with [`~FromSingleFileMixin.from_single_file`].
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXImageToVideoPipeline
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/ltx-video-2b-v0.9.safetensors"
|
||||
text_encoder = T5EncoderModel.from_pretrained(
|
||||
"Lightricks/LTX-Video", subfolder="text_encoder", torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = T5Tokenizer.from_pretrained(
|
||||
"Lightricks/LTX-Video", subfolder="tokenizer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = LTXImageToVideoPipeline.from_single_file(
|
||||
single_file_url, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
Loading [LTX GGUF checkpoints](https://huggingface.co/city96/LTX-Video-gguf) are also supported:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers import LTXPipeline, LTXVideoTransformer3DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/city96/LTX-Video-gguf/blob/main/ltx-video-2b-v0.9-Q3_K_S.gguf"
|
||||
)
|
||||
transformer = LTXVideoTransformer3DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = LTXPipeline.from_pretrained(
|
||||
"Lightricks/LTX-Video",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage"
|
||||
negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
video = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
width=704,
|
||||
height=480,
|
||||
num_frames=161,
|
||||
num_inference_steps=50,
|
||||
).frames[0]
|
||||
export_to_video(video, "output_gguf_ltx.mp4", fps=24)
|
||||
```
|
||||
|
||||
Make sure to read the [documentation on GGUF](../../quantization/gguf) to learn more about our GGUF support.
|
||||
|
||||
Refer to [this section](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox#memory-optimization) to learn more about optimizing memory consumption.
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
[[autodoc]] LTXPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXImageToVideoPipeline
|
||||
|
||||
[[autodoc]] LTXImageToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.ltx.pipeline_output.LTXPipelineOutput
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# Mochi
|
||||
# Mochi 1 Preview
|
||||
|
||||
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
|
||||
|
||||
@@ -25,6 +25,201 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
## Generating videos with Mochi-1 Preview
|
||||
|
||||
The following example will download the full precision `mochi-1-preview` weights and produce the highest quality results but will require at least 42GB VRAM to run.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview")
|
||||
|
||||
# Enable memory savings
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
|
||||
with torch.autocast("cuda", torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(prompt, num_frames=85).frames[0]
|
||||
|
||||
export_to_video(frames, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Using a lower precision variant to save memory
|
||||
|
||||
The following example will use the `bfloat16` variant of the model and requires 22GB VRAM to run. There is a slight drop in the quality of the generated video as a result.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", variant="bf16", torch_dtype=torch.bfloat16)
|
||||
|
||||
# Enable memory savings
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
frames = pipe(prompt, num_frames=85).frames[0]
|
||||
|
||||
export_to_video(frames, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Reproducing the results from the Genmo Mochi repo
|
||||
|
||||
The [Genmo Mochi implementation](https://github.com/genmoai/mochi/tree/main) uses different precision values for each stage in the inference process. The text encoder and VAE use `torch.float32`, while the DiT uses `torch.bfloat16` with the [attention kernel](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html#torch.nn.attention.sdpa_kernel) set to `EFFICIENT_ATTENTION`. Diffusers pipelines currently do not support setting different `dtypes` for different stages of the pipeline. In order to run inference in the same way as the the original implementation, please refer to the following example.
|
||||
|
||||
<Tip>
|
||||
The original Mochi implementation zeros out empty prompts. However, enabling this option and placing the entire pipeline under autocast can lead to numerical overflows with the T5 text encoder.
|
||||
|
||||
When enabling `force_zeros_for_empty_prompt`, it is recommended to run the text encoding step outside the autocast context in full precision.
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
Decoding the latents in full precision is very memory intensive. You will need at least 70GB VRAM to generate the 163 frames in this example. To reduce memory, either reduce the number of frames or run the decoding step in `torch.bfloat16`.
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
|
||||
from diffusers import MochiPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", force_zeros_for_empty_prompt=True)
|
||||
pipe.enable_vae_tiling()
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "An aerial shot of a parade of elephants walking across the African savannah. The camera showcases the herd and the surrounding landscape."
|
||||
|
||||
with torch.no_grad():
|
||||
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
|
||||
pipe.encode_prompt(prompt=prompt)
|
||||
)
|
||||
|
||||
with torch.autocast("cuda", torch.bfloat16):
|
||||
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
|
||||
frames = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
guidance_scale=4.5,
|
||||
num_inference_steps=64,
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=163,
|
||||
generator=torch.Generator("cuda").manual_seed(0),
|
||||
output_type="latent",
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
video_processor = VideoProcessor(vae_scale_factor=8)
|
||||
has_latents_mean = hasattr(pipe.vae.config, "latents_mean") and pipe.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(pipe.vae.config, "latents_std") and pipe.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(frames.device, frames.dtype)
|
||||
)
|
||||
frames = frames * latents_std / pipe.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
frames = frames / pipe.vae.config.scaling_factor
|
||||
|
||||
with torch.no_grad():
|
||||
video = pipe.vae.decode(frames.to(pipe.vae.dtype), return_dict=False)[0]
|
||||
|
||||
video = video_processor.postprocess_video(video)[0]
|
||||
export_to_video(video, "mochi.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Running inference with multiple GPUs
|
||||
|
||||
It is possible to split the large Mochi transformer across multiple GPUs using the `device_map` and `max_memory` options in `from_pretrained`. In the following example we split the model across two GPUs, each with 24GB of VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "genmo/mochi-1-preview"
|
||||
transformer = MochiTransformer3DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
device_map="auto",
|
||||
max_memory={0: "24GB", 1: "24GB"}
|
||||
)
|
||||
|
||||
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(
|
||||
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
|
||||
negative_prompt="",
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=85,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=4.5,
|
||||
num_videos_per_prompt=1,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
max_sequence_length=256,
|
||||
output_type="pil",
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
## Using single file loading with the Mochi Transformer
|
||||
|
||||
You can use `from_single_file` to load the Mochi transformer in its original format.
|
||||
|
||||
<Tip>
|
||||
Diffusers currently doesn't support using the FP8 scaled versions of the Mochi single file checkpoints.
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import MochiPipeline, MochiTransformer3DModel
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "genmo/mochi-1-preview"
|
||||
|
||||
ckpt_path = "https://huggingface.co/Comfy-Org/mochi_preview_repackaged/blob/main/split_files/diffusion_models/mochi_preview_bf16.safetensors"
|
||||
|
||||
transformer = MochiTransformer3DModel.from_pretrained(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
|
||||
pipe = MochiPipeline.from_pretrained(model_id, transformer=transformer)
|
||||
pipe.enable_model_cpu_offload()
|
||||
pipe.enable_vae_tiling()
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16, cache_enabled=False):
|
||||
frames = pipe(
|
||||
prompt="Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
|
||||
negative_prompt="",
|
||||
height=480,
|
||||
width=848,
|
||||
num_frames=85,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=4.5,
|
||||
num_videos_per_prompt=1,
|
||||
generator=torch.Generator(device="cuda").manual_seed(0),
|
||||
max_sequence_length=256,
|
||||
output_type="pil",
|
||||
).frames[0]
|
||||
|
||||
export_to_video(frames, "output.mp4", fps=30)
|
||||
```
|
||||
|
||||
## MochiPipeline
|
||||
|
||||
[[autodoc]] MochiPipeline
|
||||
|
||||
@@ -48,6 +48,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPAGInpaintPipeline
|
||||
[[autodoc]] StableDiffusionPAGInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionPAGPipeline
|
||||
[[autodoc]] StableDiffusionPAGPipeline
|
||||
- all
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# SanaPipeline
|
||||
|
||||
[SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers](https://huggingface.co/papers/2410.10629) from NVIDIA and MIT HAN Lab, by Enze Xie, Junsong Chen, Junyu Chen, Han Cai, Haotian Tang, Yujun Lin, Zhekai Zhang, Muyang Li, Ligeng Zhu, Yao Lu, Song Han.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We introduce Sana, a text-to-image framework that can efficiently generate images up to 4096×4096 resolution. Sana can synthesize high-resolution, high-quality images with strong text-image alignment at a remarkably fast speed, deployable on laptop GPU. Core designs include: (1) Deep compression autoencoder: unlike traditional AEs, which compress images only 8×, we trained an AE that can compress images 32×, effectively reducing the number of latent tokens. (2) Linear DiT: we replace all vanilla attention in DiT with linear attention, which is more efficient at high resolutions without sacrificing quality. (3) Decoder-only text encoder: we replaced T5 with modern decoder-only small LLM as the text encoder and designed complex human instruction with in-context learning to enhance the image-text alignment. (4) Efficient training and sampling: we propose Flow-DPM-Solver to reduce sampling steps, with efficient caption labeling and selection to accelerate convergence. As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.g. Flux-12B), being 20 times smaller and 100+ times faster in measured throughput. Moreover, Sana-0.6B can be deployed on a 16GB laptop GPU, taking less than 1 second to generate a 1024×1024 resolution image. Sana enables content creation at low cost. Code and model will be publicly released.*
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
This pipeline was contributed by [lawrence-cj](https://github.com/lawrence-cj) and [chenjy2003](https://github.com/chenjy2003). The original codebase can be found [here](https://github.com/NVlabs/Sana). The original weights can be found under [hf.co/Efficient-Large-Model](https://huggingface.co/Efficient-Large-Model).
|
||||
|
||||
Available models:
|
||||
|
||||
| Model | Recommended dtype |
|
||||
|:-----:|:-----------------:|
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | `torch.bfloat16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_600M_1024px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px_diffusers) | `torch.float16` |
|
||||
| [`Efficient-Large-Model/Sana_600M_512px_diffusers`](https://huggingface.co/Efficient-Large-Model/Sana_600M_512px_diffusers) | `torch.float16` |
|
||||
|
||||
Refer to [this](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) collection for more information.
|
||||
|
||||
Note: The recommended dtype mentioned is for the transformer weights. The text encoder and VAE weights must stay in `torch.bfloat16` or `torch.float32` for the model to work correctly. Please refer to the inference example below to see how to load the model with the recommended dtype.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to pass the `variant` argument for downloaded checkpoints to use lower disk space. Set it to `"fp16"` for models with recommended dtype as `torch.float16`, and `"bf16"` for models with recommended dtype as `torch.bfloat16`. By default, `torch.float32` weights are downloaded, which use twice the amount of disk storage. Additionally, `torch.float32` weights can be downcasted on-the-fly by specifying the `torch_dtype` argument. Read about it in the [docs](https://huggingface.co/docs/diffusers/v0.31.0/en/api/pipelines/overview#diffusers.DiffusionPipeline.from_pretrained).
|
||||
|
||||
</Tip>
|
||||
|
||||
## SanaPipeline
|
||||
|
||||
[[autodoc]] SanaPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## SanaPAGPipeline
|
||||
|
||||
[[autodoc]] SanaPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## SanaPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.sana.pipeline_output.SanaPipelineOutput
|
||||
@@ -59,9 +59,76 @@ image.save("sd3_hello_world.png")
|
||||
- [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large)
|
||||
- [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo)
|
||||
|
||||
## Image Prompting with IP-Adapters
|
||||
|
||||
An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need:
|
||||
|
||||
- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder.
|
||||
- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`.
|
||||
- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection.
|
||||
|
||||
IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
from transformers import SiglipVisionModel, SiglipImageProcessor
|
||||
|
||||
image_encoder_id = "google/siglip-so400m-patch14-384"
|
||||
ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter"
|
||||
|
||||
feature_extractor = SiglipImageProcessor.from_pretrained(
|
||||
image_encoder_id,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
image_encoder = SiglipVisionModel.from_pretrained(
|
||||
image_encoder_id,
|
||||
torch_dtype=torch.float16
|
||||
).to( "cuda")
|
||||
|
||||
pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3.5-large",
|
||||
torch_dtype=torch.float16,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
).to("cuda")
|
||||
|
||||
pipe.load_ip_adapter(ip_adapter_id)
|
||||
pipe.set_ip_adapter_scale(0.6)
|
||||
|
||||
ref_img = Image.open("image.jpg").convert('RGB')
|
||||
|
||||
image = pipe(
|
||||
width=1024,
|
||||
height=1024,
|
||||
prompt="a cat",
|
||||
negative_prompt="lowres, low quality, worst quality",
|
||||
num_inference_steps=24,
|
||||
guidance_scale=5.0,
|
||||
ip_adapter_image=ref_img
|
||||
).images[0]
|
||||
|
||||
image.save("result.jpg")
|
||||
```
|
||||
|
||||
<div class="justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd3_ip_adapter_example.png"/>
|
||||
<figcaption class="mt-2 text-sm text-center text-gray-500">IP-Adapter examples with prompt "a cat"</figcaption>
|
||||
</div>
|
||||
|
||||
|
||||
<Tip>
|
||||
|
||||
Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
## Memory Optimisations for SD3
|
||||
|
||||
SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
|
||||
SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware.
|
||||
|
||||
### Running Inference with Model Offloading
|
||||
|
||||
|
||||
@@ -28,6 +28,13 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
|
||||
|
||||
[[autodoc]] BitsAndBytesConfig
|
||||
|
||||
## GGUFQuantizationConfig
|
||||
|
||||
[[autodoc]] GGUFQuantizationConfig
|
||||
## TorchAoConfig
|
||||
|
||||
[[autodoc]] TorchAoConfig
|
||||
|
||||
## DiffusersQuantizer
|
||||
|
||||
[[autodoc]] quantizers.base.DiffusersQuantizer
|
||||
|
||||
@@ -17,6 +17,12 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.
|
||||
|
||||
This guide demonstrates how quantization can enable running
|
||||
[FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
|
||||
on less than 16GB of VRAM and even on a free Google
|
||||
Colab instance.
|
||||
|
||||

|
||||
|
||||
To use bitsandbytes, make sure you have the following libraries installed:
|
||||
|
||||
@@ -31,70 +37,167 @@ Now you can quantize a model by passing a [`BitsAndBytesConfig`] to [`~ModelMixi
|
||||
|
||||
Quantizing a model in 8-bit halves the memory-usage:
|
||||
|
||||
bitsandbytes is supported in both Transformers and Diffusers, so you can quantize both the
|
||||
[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].
|
||||
|
||||
For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bfloat16`.
|
||||
|
||||
> [!TIP]
|
||||
> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.
|
||||
|
||||
```py
|
||||
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
model_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True,)
|
||||
|
||||
text_encoder_2_8bit = T5EncoderModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="text_encoder_2",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True,)
|
||||
|
||||
transformer_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
|
||||
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
|
||||
|
||||
```py
|
||||
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
model_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
```diff
|
||||
transformer_8bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.float32
|
||||
quantization_config=quant_config,
|
||||
+ torch_dtype=torch.float32,
|
||||
)
|
||||
model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
|
||||
```
|
||||
|
||||
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
|
||||
Let's generate an image using our quantized models.
|
||||
|
||||
Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the
|
||||
CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
|
||||
|
||||
```py
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer_8bit,
|
||||
text_encoder_2=text_encoder_2_8bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
pipe_kwargs = {
|
||||
"prompt": "A cat holding a sign that says hello world",
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"guidance_scale": 3.5,
|
||||
"num_inference_steps": 50,
|
||||
"max_sequence_length": 512,
|
||||
}
|
||||
|
||||
image = pipe(**pipe_kwargs, generator=torch.manual_seed(0),).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/quant-bnb/8bit.png"/>
|
||||
</div>
|
||||
|
||||
When there is enough memory, you can also directly move the pipeline to the GPU with `.to("cuda")` and apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.
|
||||
|
||||
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 8-bit models locally with [`~ModelMixin.save_pretrained`].
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="4-bit">
|
||||
|
||||
Quantizing a model in 4-bit reduces your memory-usage by 4x:
|
||||
|
||||
bitsandbytes is supported in both Transformers and Diffusers, so you can can quantize both the
|
||||
[`FluxTransformer2DModel`] and [`~transformers.T5EncoderModel`].
|
||||
|
||||
For Ada and higher-series GPUs. we recommend changing `torch_dtype` to `torch.bfloat16`.
|
||||
|
||||
> [!TIP]
|
||||
> The [`CLIPTextModel`] and [`AutoencoderKL`] aren't quantized because they're already small in size and because [`AutoencoderKL`] only has a few `torch.nn.Linear` layers.
|
||||
|
||||
```py
|
||||
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
model_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quant_config = TransformersBitsAndBytesConfig(load_in_4bit=True,)
|
||||
|
||||
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="text_encoder_2",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(load_in_4bit=True,)
|
||||
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:
|
||||
By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter.
|
||||
|
||||
```py
|
||||
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig
|
||||
|
||||
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
|
||||
|
||||
model_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
```diff
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=torch.float32
|
||||
quantization_config=quant_config,
|
||||
+ torch_dtype=torch.float32,
|
||||
)
|
||||
model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype
|
||||
```
|
||||
|
||||
Call [`~ModelMixin.push_to_hub`] after loading it in 4-bit precision. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
|
||||
Let's generate an image using our quantized models.
|
||||
|
||||
Setting `device_map="auto"` automatically fills all available space on the GPU(s) first, then the CPU, and finally, the hard drive (the absolute slowest option) if there is still not enough memory.
|
||||
|
||||
```py
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer_4bit,
|
||||
text_encoder_2=text_encoder_2_4bit,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
|
||||
pipe_kwargs = {
|
||||
"prompt": "A cat holding a sign that says hello world",
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"guidance_scale": 3.5,
|
||||
"num_inference_steps": 50,
|
||||
"max_sequence_length": 512,
|
||||
}
|
||||
|
||||
image = pipe(**pipe_kwargs, generator=torch.manual_seed(0),).images[0]
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/quant-bnb/4bit.png"/>
|
||||
</div>
|
||||
|
||||
When there is enough memory, you can also directly move the pipeline to the GPU with `.to("cuda")` and apply [`~DiffusionPipeline.enable_model_cpu_offload`] to optimize GPU memory usage.
|
||||
|
||||
Once a model is quantized, you can push the model to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights. You can also save the serialized 4-bit models locally with [`~ModelMixin.save_pretrained`].
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
@@ -199,17 +302,34 @@ quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dty
|
||||
NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:
|
||||
|
||||
```py
|
||||
from diffusers import BitsAndBytesConfig
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
nf4_config = BitsAndBytesConfig(
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
quant_config = TransformersBitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
model_nf4 = SD3Transformer2DModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="text_encoder_2",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=nf4_config,
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
@@ -220,38 +340,74 @@ For inference, the `bnb_4bit_quant_type` does not have a huge impact on performa
|
||||
Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.
|
||||
|
||||
```py
|
||||
from diffusers import BitsAndBytesConfig
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
double_quant_config = BitsAndBytesConfig(
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
quant_config = TransformersBitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
double_quant_model = SD3Transformer2DModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="text_encoder_2",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=double_quant_config,
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
```
|
||||
|
||||
## Dequantizing `bitsandbytes` models
|
||||
|
||||
Once quantized, you can dequantize the model to the original precision but this might result in a small quality loss of the model. Make sure you have enough GPU RAM to fit the dequantized model.
|
||||
Once quantized, you can dequantize a model to its original precision, but this might result in a small loss of quality. Make sure you have enough GPU RAM to fit the dequantized model.
|
||||
|
||||
```python
|
||||
from diffusers import BitsAndBytesConfig
|
||||
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
|
||||
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
|
||||
|
||||
double_quant_config = BitsAndBytesConfig(
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
quant_config = TransformersBitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
double_quant_model = SD3Transformer2DModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
subfolder="transformer",
|
||||
quantization_config=double_quant_config,
|
||||
text_encoder_2_4bit = T5EncoderModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="text_encoder_2",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
model.dequantize()
|
||||
|
||||
quant_config = DiffusersBitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
|
||||
transformer_4bit = FluxTransformer2DModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
quantization_config=quant_config,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
text_encoder_2_4bit.dequantize()
|
||||
transformer_4bit.dequantize()
|
||||
```
|
||||
|
||||
## Resources
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
|
||||
-->
|
||||
|
||||
# GGUF
|
||||
|
||||
The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block wise quantization options. Diffusers supports loading checkpoints prequantized and saved in the GGUF format via `from_single_file` loading with Model classes. Loading GGUF checkpoints via Pipelines is currently not supported.
|
||||
|
||||
The following example will load the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.
|
||||
|
||||
Before starting please install gguf in your environment
|
||||
|
||||
```shell
|
||||
pip install -U gguf
|
||||
```
|
||||
|
||||
Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].
|
||||
|
||||
When using GGUF checkpoints, the quantized weights remain in a low memory `dtype`(typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.
|
||||
|
||||
The functions used for dynamic dequantizatation are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the Pytorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig
|
||||
|
||||
ckpt_path = (
|
||||
"https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
|
||||
)
|
||||
transformer = FluxTransformer2DModel.from_single_file(
|
||||
ckpt_path,
|
||||
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
transformer=transformer,
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
|
||||
image.save("flux-gguf.png")
|
||||
```
|
||||
|
||||
## Supported Quantization Types
|
||||
|
||||
- BF16
|
||||
- Q4_0
|
||||
- Q4_1
|
||||
- Q5_0
|
||||
- Q5_1
|
||||
- Q8_0
|
||||
- Q2_K
|
||||
- Q3_K
|
||||
- Q4_K
|
||||
- Q5_K
|
||||
- Q6_K
|
||||
|
||||
@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a
|
||||
|
||||
<Tip>
|
||||
|
||||
Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
|
||||
Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
|
||||
|
||||
</Tip>
|
||||
|
||||
@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be
|
||||
|
||||
## When to use what?
|
||||
|
||||
This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
Diffusers currently supports the following quantization methods.
|
||||
- [BitsandBytes](./bitsandbytes)
|
||||
- [TorchAO](./torchao)
|
||||
- [GGUF](./gguf)
|
||||
|
||||
[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# torchao
|
||||
|
||||
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, featuring composability with native PyTorch features like [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html), FullyShardedDataParallel (FSDP), and more.
|
||||
|
||||
Before you begin, make sure you have Pytorch 2.5+ and TorchAO installed.
|
||||
|
||||
```bash
|
||||
pip install -U torch torchao
|
||||
```
|
||||
|
||||
|
||||
Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`] (you can also load pre-quantized models). This works for any model in any modality, as long as it supports loading with [Accelerate](https://hf.co/docs/accelerate/index) and contains `torch.nn.Linear` layers.
|
||||
|
||||
The example below only quantizes the weights to int8.
|
||||
|
||||
```python
|
||||
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
|
||||
|
||||
model_id = "black-forest-labs/Flux.1-Dev"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
quantization_config = TorchAoConfig("int8wo")
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="transformer",
|
||||
quantization_config=quantization_config,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
model_id,
|
||||
transformer=transformer,
|
||||
torch_dtype=dtype,
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
image = pipe(prompt, num_inference_steps=28, guidance_scale=0.0).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
TorchAO is fully compatible with [torch.compile](./optimization/torch2.0#torchcompile), setting it apart from other quantization methods. This makes it easy to speed up inference with just one line of code.
|
||||
|
||||
```python
|
||||
# In the above code, add the following after initializing the transformer
|
||||
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
|
||||
```
|
||||
|
||||
For speed and memory benchmarks on Flux and CogVideoX, please refer to the table [here](https://github.com/huggingface/diffusers/pull/10009#issue-2688781450). You can also find some torchao [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) numbers for various hardware.
|
||||
|
||||
torchao also supports an automatic quantization API through [autoquant](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md#autoquantization). Autoquantization determines the best quantization strategy applicable to a model by comparing the performance of each technique on chosen input types and shapes. Currently, this can be used directly on the underlying modeling components. Diffusers will also expose an autoquant configuration option in the future.
|
||||
|
||||
The `TorchAoConfig` class accepts three parameters:
|
||||
- `quant_type`: A string value mentioning one of the quantization types below.
|
||||
- `modules_to_not_convert`: A list of module full/partial module names for which quantization should not be performed. For example, to not perform any quantization of the [`FluxTransformer2DModel`]'s first block, one would specify: `modules_to_not_convert=["single_transformer_blocks.0"]`.
|
||||
- `kwargs`: A dict of keyword arguments to pass to the underlying quantization method which will be invoked based on `quant_type`.
|
||||
|
||||
## Supported quantization types
|
||||
|
||||
torchao supports weight-only quantization and weight and dynamic-activation quantization for int8, float3-float8, and uint1-uint7.
|
||||
|
||||
Weight-only quantization stores the model weights in a specific low-bit data type but performs computation with a higher-precision data type, like `bfloat16`. This lowers the memory requirements from model weights but retains the memory peaks for activation computation.
|
||||
|
||||
Dynamic activation quantization stores the model weights in a low-bit dtype, while also quantizing the activations on-the-fly to save additional memory. This lowers the memory requirements from model weights, while also lowering the memory overhead from activation computations. However, this may come at a quality tradeoff at times, so it is recommended to test different models thoroughly.
|
||||
|
||||
The quantization methods supported are as follows:
|
||||
|
||||
| **Category** | **Full Function Names** | **Shorthands** |
|
||||
|--------------|-------------------------|----------------|
|
||||
| **Integer quantization** | `int4_weight_only`, `int8_dynamic_activation_int4_weight`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` | `int4wo`, `int4dq`, `int8wo`, `int8dq` |
|
||||
| **Floating point 8-bit quantization** | `float8_weight_only`, `float8_dynamic_activation_float8_weight`, `float8_static_activation_float8_weight` | `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`, `float8_e4m3_tensor`, `float8_e4m3_row` |
|
||||
| **Floating point X-bit quantization** | `fpx_weight_only` | `fpX_eAwB` where `X` is the number of bits (1-7), `A` is exponent bits, and `B` is mantissa bits. Constraint: `X == A + B + 1` |
|
||||
| **Unsigned Integer quantization** | `uintx_weight_only` | `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo` |
|
||||
|
||||
Some quantization methods are aliases (for example, `int8wo` is the commonly used shorthand for `int8_weight_only`). This allows using the quantization methods described in the torchao docs as-is, while also making it convenient to remember their shorthand notations.
|
||||
|
||||
Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
|
||||
|
||||
## Resources
|
||||
|
||||
- [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
|
||||
- [Diffusers-TorchAO examples](https://github.com/sayakpaul/diffusers-torchao)
|
||||
@@ -56,7 +56,7 @@ image
|
||||
|
||||
With the `adapter_name` parameter, it is really easy to use another adapter for inference! Load the [nerijs/pixel-art-xl](https://huggingface.co/nerijs/pixel-art-xl) adapter that has been fine-tuned to generate pixel art images and call it `"pixel"`.
|
||||
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method:
|
||||
The pipeline automatically sets the first loaded adapter (`"toy"`) as the active adapter, but you can activate the `"pixel"` adapter with the [`~PeftAdapterMixin.set_adapters`] method:
|
||||
|
||||
```python
|
||||
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
@@ -85,7 +85,7 @@ By default, if the most up-to-date versions of PEFT and Transformers are detecte
|
||||
|
||||
You can also merge different adapter checkpoints for inference to blend their styles together.
|
||||
|
||||
Once again, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
|
||||
Once again, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `pixel` and `toy` adapters and specify the weights for how they should be merged.
|
||||
|
||||
```python
|
||||
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
|
||||
@@ -114,7 +114,7 @@ Impressive! As you can see, the model generated an image that mixed the characte
|
||||
> [!TIP]
|
||||
> Through its PEFT integration, Diffusers also offers more efficient merging methods which you can learn about in the [Merge LoRAs](../using-diffusers/merge_loras) guide!
|
||||
|
||||
To return to only using one adapter, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`] method to activate the `"toy"` adapter:
|
||||
To return to only using one adapter, use the [`~PeftAdapterMixin.set_adapters`] method to activate the `"toy"` adapter:
|
||||
|
||||
```python
|
||||
pipe.set_adapters("toy")
|
||||
@@ -127,7 +127,7 @@ image = pipe(
|
||||
image
|
||||
```
|
||||
|
||||
Or to disable all adapters entirely, use the [`~diffusers.loaders.UNet2DConditionLoadersMixin.disable_lora`] method to return the base model.
|
||||
Or to disable all adapters entirely, use the [`~PeftAdapterMixin.disable_lora`] method to return the base model.
|
||||
|
||||
```python
|
||||
pipe.disable_lora()
|
||||
@@ -140,7 +140,8 @@ image
|
||||

|
||||
|
||||
### Customize adapters strength
|
||||
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~diffusers.loaders.UNet2DConditionLoadersMixin.set_adapters`].
|
||||
|
||||
For even more customization, you can control how strongly the adapter affects each part of the pipeline. For this, pass a dictionary with the control strengths (called "scales") to [`~PeftAdapterMixin.set_adapters`].
|
||||
|
||||
For example, here's how you can turn on the adapter for the `down` parts, but turn it off for the `mid` and `up` parts:
|
||||
```python
|
||||
@@ -195,7 +196,7 @@ image
|
||||
|
||||

|
||||
|
||||
## Manage active adapters
|
||||
## Manage adapters
|
||||
|
||||
You have attached multiple adapters in this tutorial, and if you're feeling a bit lost on what adapters have been attached to the pipeline's components, use the [`~diffusers.loaders.StableDiffusionLoraLoaderMixin.get_active_adapters`] method to check the list of active adapters:
|
||||
|
||||
@@ -212,3 +213,11 @@ list_adapters_component_wise = pipe.get_list_adapters()
|
||||
list_adapters_component_wise
|
||||
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
|
||||
```
|
||||
|
||||
The [`~PeftAdapterMixin.delete_adapters`] function completely removes an adapter and their LoRA layers from a model.
|
||||
|
||||
```py
|
||||
pipe.delete_adapters("toy")
|
||||
pipe.get_active_adapters()
|
||||
["pixel"]
|
||||
```
|
||||
|
||||
@@ -134,14 +134,16 @@ The [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method loads L
|
||||
- the LoRA weights don't have separate identifiers for the UNet and text encoder
|
||||
- the LoRA weights have separate identifiers for the UNet and text encoder
|
||||
|
||||
But if you only need to load LoRA weights into the UNet, then you can use the [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method. Let's load the [jbilcke-hf/sdxl-cinematic-1](https://huggingface.co/jbilcke-hf/sdxl-cinematic-1) LoRA:
|
||||
To directly load (and save) a LoRA adapter at the *model-level*, use [`~PeftAdapterMixin.load_lora_adapter`], which builds and prepares the necessary model configuration for the adapter. Like [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`], [`PeftAdapterMixin.load_lora_adapter`] can load LoRAs for both the UNet and text encoder. For example, if you're loading a LoRA for the UNet, [`PeftAdapterMixin.load_lora_adapter`] ignores the keys for the text encoder.
|
||||
|
||||
Use the `weight_name` parameter to specify the specific weight file and the `prefix` parameter to filter for the appropriate state dicts (`"unet"` in this case) to load.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
|
||||
pipeline.unet.load_attn_procs("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors")
|
||||
pipeline.unet.load_lora_adapter("jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", prefix="unet")
|
||||
|
||||
# use cnmt in the prompt to trigger the LoRA
|
||||
prompt = "A cute cnmt eating a slice of pizza, stunning color scheme, masterpiece, illustration"
|
||||
@@ -153,6 +155,8 @@ image
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_attn_proc.png" />
|
||||
</div>
|
||||
|
||||
Save an adapter with [`~PeftAdapterMixin.save_lora_adapter`].
|
||||
|
||||
To unload the LoRA weights, use the [`~loaders.StableDiffusionLoraLoaderMixin.unload_lora_weights`] method to discard the LoRA weights and restore the model to its original weights:
|
||||
|
||||
```py
|
||||
|
||||
@@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings(
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
device=device,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
|
||||
@@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings(
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
device=device,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
|
||||
@@ -241,27 +241,15 @@ from diffusers import StableDiffusionPipeline
|
||||
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
import torch
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
|
||||
pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
pipeline.safety_checker = None
|
||||
pipeline.requires_safety_checker = False
|
||||
|
||||
|
||||
class SDPromptScheduleCallback(PipelineCallback):
|
||||
class SDPromptSchedulingCallback(PipelineCallback):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
cutoff_step_ratio=1.0,
|
||||
encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
cutoff_step_ratio=None,
|
||||
cutoff_step_index=None,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -275,6 +263,10 @@ class SDPromptScheduleCallback(PipelineCallback):
|
||||
) -> Dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
if isinstance(self.config.encoded_prompt, tuple):
|
||||
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
|
||||
else:
|
||||
prompt_embeds = self.config.encoded_prompt
|
||||
|
||||
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
|
||||
cutoff_step = (
|
||||
@@ -284,34 +276,164 @@ class SDPromptScheduleCallback(PipelineCallback):
|
||||
)
|
||||
|
||||
if step_index == cutoff_step:
|
||||
prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
|
||||
prompt=self.config.prompt,
|
||||
negative_prompt=self.config.negative_prompt,
|
||||
device=pipeline._execution_device,
|
||||
num_images_per_prompt=self.config.num_images_per_prompt,
|
||||
do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
|
||||
)
|
||||
if pipeline.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
|
||||
return callback_kwargs
|
||||
|
||||
|
||||
pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
pipeline.safety_checker = None
|
||||
pipeline.requires_safety_checker = False
|
||||
|
||||
callback = MultiPipelineCallbacks(
|
||||
[
|
||||
SDPromptScheduleCallback(
|
||||
prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
|
||||
negative_prompt="Deformed, ugly, bad anatomy",
|
||||
cutoff_step_ratio=0.25,
|
||||
)
|
||||
SDPromptSchedulingCallback(
|
||||
encoded_prompt=pipeline.encode_prompt(
|
||||
prompt=f"prompt {index}",
|
||||
negative_prompt=f"negative prompt {index}",
|
||||
device=pipeline._execution_device,
|
||||
num_images_per_prompt=1,
|
||||
# pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
|
||||
do_classifier_free_guidance=True,
|
||||
),
|
||||
cutoff_step_index=index,
|
||||
) for index in range(1, 20)
|
||||
]
|
||||
)
|
||||
|
||||
image = pipeline(
|
||||
prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
|
||||
negative_prompt="Deformed, ugly, bad anatomy",
|
||||
prompt="prompt"
|
||||
negative_prompt="negative prompt",
|
||||
callback_on_step_end=callback,
|
||||
callback_on_step_end_tensor_inputs=["prompt_embeds"],
|
||||
).images[0]
|
||||
torch.cuda.empty_cache()
|
||||
image.save('image.png')
|
||||
```
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
import torch
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
|
||||
class SDXLPromptSchedulingCallback(PipelineCallback):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
cutoff_step_ratio=None,
|
||||
cutoff_step_index=None,
|
||||
):
|
||||
super().__init__(
|
||||
cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
|
||||
)
|
||||
|
||||
tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
|
||||
|
||||
def callback_fn(
|
||||
self, pipeline, step_index, timestep, callback_kwargs
|
||||
) -> Dict[str, Any]:
|
||||
cutoff_step_ratio = self.config.cutoff_step_ratio
|
||||
cutoff_step_index = self.config.cutoff_step_index
|
||||
if isinstance(self.config.encoded_prompt, tuple):
|
||||
prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
|
||||
else:
|
||||
prompt_embeds = self.config.encoded_prompt
|
||||
if isinstance(self.config.add_text_embeds, tuple):
|
||||
add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds
|
||||
else:
|
||||
add_text_embeds = self.config.add_text_embeds
|
||||
if isinstance(self.config.add_time_ids, tuple):
|
||||
add_time_ids, negative_add_time_ids = self.config.add_time_ids
|
||||
else:
|
||||
add_time_ids = self.config.add_time_ids
|
||||
|
||||
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
|
||||
cutoff_step = (
|
||||
cutoff_step_index
|
||||
if cutoff_step_index is not None
|
||||
else int(pipeline.num_timesteps * cutoff_step_ratio)
|
||||
)
|
||||
|
||||
if step_index == cutoff_step:
|
||||
if pipeline.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
|
||||
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
|
||||
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
|
||||
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
|
||||
return callback_kwargs
|
||||
|
||||
|
||||
pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
|
||||
callbacks = []
|
||||
for index in range(1, 20):
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = pipeline.encode_prompt(
|
||||
prompt=f"prompt {index}",
|
||||
negative_prompt=f"prompt {index}",
|
||||
device=pipeline._execution_device,
|
||||
num_images_per_prompt=1,
|
||||
# pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
|
||||
do_classifier_free_guidance=True,
|
||||
)
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
add_time_ids = pipeline._get_add_time_ids(
|
||||
(1024, 1024),
|
||||
(0, 0),
|
||||
(1024, 1024),
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
negative_add_time_ids = pipeline._get_add_time_ids(
|
||||
(1024, 1024),
|
||||
(0, 0),
|
||||
(1024, 1024),
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
callbacks.append(
|
||||
SDXLPromptSchedulingCallback(
|
||||
encoded_prompt=(prompt_embeds, negative_prompt_embeds),
|
||||
add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds),
|
||||
add_time_ids=(add_time_ids, negative_add_time_ids),
|
||||
cutoff_step_index=index,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
callback = MultiPipelineCallbacks(callbacks)
|
||||
|
||||
image = pipeline(
|
||||
prompt="prompt",
|
||||
negative_prompt="negative prompt",
|
||||
callback_on_step_end=callback,
|
||||
callback_on_step_end_tensor_inputs=[
|
||||
"prompt_embeds",
|
||||
"add_text_embeds",
|
||||
"add_time_ids",
|
||||
],
|
||||
).images[0]
|
||||
```
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1008,6 +1008,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
self.transformer.inner_dim // self.transformer.num_heads,
|
||||
grid_crops_coords,
|
||||
(grid_height, grid_width),
|
||||
device=device,
|
||||
output_type="pt",
|
||||
)
|
||||
|
||||
style = torch.tensor([0], device=device)
|
||||
|
||||
@@ -129,7 +129,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
self.power = int(rp_args["power"]) if "power" in rp_args else 1
|
||||
|
||||
prompts = prompt if isinstance(prompt, list) else [prompt]
|
||||
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
|
||||
n_prompts = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt]
|
||||
self.batch = batch = num_images_per_prompt * len(prompts)
|
||||
|
||||
if use_base:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# ControlNet training example for Stable Diffusion 3 (SD3)
|
||||
# ControlNet training example for Stable Diffusion 3/3.5 (SD3/3.5)
|
||||
|
||||
The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
|
||||
The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206) and [Stable Diffusion 3.5](https://stability.ai/news/introducing-stable-diffusion-3-5).
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
@@ -51,9 +51,9 @@ Please download the dataset and unzip it in the directory `fill50k` in the `exam
|
||||
|
||||
## Training
|
||||
|
||||
First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium). We will use it as a base model for the ControlNet training.
|
||||
First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or the SD3.5 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium). We will use it as a base model for the ControlNet training.
|
||||
> [!NOTE]
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or [Stable Diffusion 3.5 Large Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
@@ -90,6 +90,8 @@ accelerate launch train_controlnet_sd3.py \
|
||||
--gradient_accumulation_steps=4
|
||||
```
|
||||
|
||||
To train a ControlNet model for Stable Diffusion 3.5, replace the `MODEL_DIR` with `stabilityai/stable-diffusion-3.5-medium`.
|
||||
|
||||
To better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
@@ -124,6 +126,8 @@ image = pipe(
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
Similarly, for SD3.5, replace the `base_model_path` with `stabilityai/stable-diffusion-3.5-medium` and controlnet_path `DavyMorgan/sd35-controlnet-out'.
|
||||
|
||||
## Notes
|
||||
|
||||
### GPU usage
|
||||
@@ -135,6 +139,8 @@ Make sure to use the right GPU when configuring the [accelerator](https://huggin
|
||||
|
||||
## Example results
|
||||
|
||||
### SD3
|
||||
|
||||
#### After 500 steps with batch size 8
|
||||
|
||||
| | |
|
||||
@@ -150,3 +156,20 @@ Make sure to use the right GPU when configuring the [accelerator](https://huggin
|
||||
|| pale golden rod circle with old lace background |
|
||||
 |  |
|
||||
|
||||
### SD3.5
|
||||
|
||||
#### After 500 steps with batch size 8
|
||||
|
||||
| | |
|
||||
|-------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------:|
|
||||
|| pale golden rod circle with old lace background |
|
||||
 |  |
|
||||
|
||||
|
||||
#### After 3000 steps with batch size 8:
|
||||
|
||||
| | |
|
||||
|-------------------|:----------------------------------------------------------------------------------------------------------------------------------------------------:|
|
||||
|| pale golden rod circle with old lace background |
|
||||
 |  |
|
||||
|
||||
|
||||
@@ -138,6 +138,27 @@ class ControlNetSD3(ExamplesTestsAccelerate):
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
|
||||
class ControlNetSD35(ExamplesTestsAccelerate):
|
||||
def test_controlnet_sd3(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet_sd3.py
|
||||
--pretrained_model_name_or_path=hf-internal-testing/tiny-sd35-pipe
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd35
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
|
||||
class ControlNetflux(ExamplesTestsAccelerate):
|
||||
def test_controlnet_flux(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
||||
@@ -263,6 +263,12 @@ def parse_args(input_args=None):
|
||||
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
|
||||
" If not specified controlnet weights are initialized from unet.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_extra_conditioning_channels",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of extra conditioning channels for controlnet.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
@@ -539,6 +545,9 @@ def parse_args(input_args=None):
|
||||
default=77,
|
||||
help="Maximum sequence length to use with with the T5 text encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_preprocess_batch_size", type=int, default=1000, help="Batch size for preprocessing dataset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompt",
|
||||
type=str,
|
||||
@@ -986,7 +995,9 @@ def main(args):
|
||||
controlnet = SD3ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
|
||||
else:
|
||||
logger.info("Initializing controlnet weights from transformer")
|
||||
controlnet = SD3ControlNetModel.from_transformer(transformer)
|
||||
controlnet = SD3ControlNetModel.from_transformer(
|
||||
transformer, num_extra_conditioning_channels=args.num_extra_conditioning_channels
|
||||
)
|
||||
|
||||
transformer.requires_grad_(False)
|
||||
vae.requires_grad_(False)
|
||||
@@ -1123,7 +1134,12 @@ def main(args):
|
||||
# fingerprint used by the cache for the other processes to load the result
|
||||
# details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
|
||||
new_fingerprint = Hasher.hash(args)
|
||||
train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
|
||||
train_dataset = train_dataset.map(
|
||||
compute_embeddings_fn,
|
||||
batched=True,
|
||||
batch_size=args.dataset_preprocess_batch_size,
|
||||
new_fingerprint=new_fingerprint,
|
||||
)
|
||||
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
del tokenizer_one, tokenizer_two, tokenizer_three
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
# DreamBooth training example for SANA
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few (3~5) images of a subject.
|
||||
|
||||
The `train_dreambooth_lora_sana.py` script shows how to implement the training procedure with [LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) and adapt it for [SANA](https://arxiv.org/abs/2410.10629).
|
||||
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/dreambooth` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sana.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.14.0` installed in your environment.
|
||||
|
||||
|
||||
### Dog toy example
|
||||
|
||||
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./dog"
|
||||
snapshot_download(
|
||||
"diffusers/dog-example",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
|
||||
|
||||
Now, we can launch training using:
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers"
|
||||
export INSTANCE_DIR="dog"
|
||||
export OUTPUT_DIR="trained-sana-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sana.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--use_8bit_adam \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--validation_prompt="A photo of sks dog in a bucket" \
|
||||
--validation_epochs=25 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
For using `push_to_hub`, make you're logged into your Hugging Face account:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
## Notes
|
||||
|
||||
Additionally, we welcome you to explore the following CLI arguments:
|
||||
|
||||
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma seperated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
|
||||
* `--complex_human_instruction`: Instructions for complex human attention as shown in [here](https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55).
|
||||
* `--max_sequence_length`: Maximum sequence length to use for text embeddings.
|
||||
|
||||
|
||||
We provide several options for optimizing memory optimization:
|
||||
|
||||
* `--offload`: When enabled, we will offload the text encoder and VAE to CPU, when they are not used.
|
||||
* `cache_latents`: When enabled, we will pre-compute the latents from the input images with the VAE and remove the VAE from memory once done.
|
||||
* `--use_8bit_adam`: When enabled, we will use the 8bit version of AdamW provided by the `bitsandbytes` library.
|
||||
|
||||
Refer to the [official documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana) of the `SanaPipeline` to know more about the models available under the SANA family and their preferred dtypes during inference.
|
||||
@@ -0,0 +1,8 @@
|
||||
accelerate>=1.0.0
|
||||
torchvision
|
||||
transformers>=4.47.0
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft>=0.14.0
|
||||
sentencepiece
|
||||
@@ -1300,16 +1300,17 @@ def main(args):
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
base_weight = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective needs to be floored to an SNR weight of one.
|
||||
mse_loss_weights = base_weight + 1
|
||||
divisor = snr + 1
|
||||
else:
|
||||
# Epsilon and sample both use the same loss weights.
|
||||
mse_loss_weights = base_weight
|
||||
divisor = snr
|
||||
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor
|
||||
)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,204 @@
|
||||
# Training Flux Control
|
||||
|
||||
This (experimental) example shows how to train Control LoRAs with [Flux](https://huggingface.co/black-forest-labs/FLUX.1-dev) by conditioning it with additional structural controls (like depth maps, poses, etc.). We provide a script for full fine-tuning, too, refer to [this section](#full-fine-tuning). To know more about Flux Control family, refer to the following resources:
|
||||
|
||||
* [Docs](https://github.com/black-forest-labs/flux/blob/main/docs/structural-conditioning.md) by Black Forest Labs
|
||||
* Diffusers docs ([1](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#canny-control), [2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux#depth-control))
|
||||
|
||||
To incorporate additional condition latents, we expand the input features of Flux.1-Dev from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens on the `x_embedder` layer, where the combined latents are projected to the expected feature dimension of rest of the network. Inference is performed using the `FluxControlPipeline`.
|
||||
|
||||
> [!NOTE]
|
||||
> **Gated model**
|
||||
>
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
|
||||
|
||||
```bash
|
||||
accelerate launch train_control_lora_flux.py \
|
||||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
||||
--dataset_name="raulc0399/open_pose_controlnet" \
|
||||
--output_dir="pose-control-lora" \
|
||||
--mixed_precision="bf16" \
|
||||
--train_batch_size=1 \
|
||||
--rank=64 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--learning_rate=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=5000 \
|
||||
--validation_image="openpose.png" \
|
||||
--validation_prompt="A couple, 4k photo, highly detailed" \
|
||||
--offload \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
|
||||
|
||||
You need to install `diffusers` from the branch of [this PR](https://github.com/huggingface/diffusers/pull/9999). When it's merged, you should install `diffusers` from the `main`.
|
||||
|
||||
The training script exposes additional CLI args that might be useful to experiment with:
|
||||
|
||||
* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layer.
|
||||
* `train_norm_layers`: When set, additionally trains the normalization scales. Takes care of saving and loading.
|
||||
* `lora_layers`: Specify the layers you want to apply LoRA to. If you specify "all-linear", all the linear layers will be LoRA-attached.
|
||||
|
||||
### Training with DeepSpeed
|
||||
|
||||
It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging the Zero2 system optimization. To use it, save the following config to an YAML file (feel free to modify as needed):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
```
|
||||
|
||||
And then while launching training, pass the config file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=CONFIG_FILE.yaml ...
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
|
||||
|
||||
```bash
|
||||
pip install controlnet_aux
|
||||
```
|
||||
|
||||
And then we are ready:
|
||||
|
||||
```py
|
||||
from controlnet_aux import OpenposeDetector
|
||||
from diffusers import FluxControlPipeline
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe.load_lora_weights("...") # change this.
|
||||
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
# prepare pose condition.
|
||||
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
|
||||
image = load_image(url)
|
||||
image = open_pose(image, detect_resolution=512, image_resolution=1024)
|
||||
image = np.array(image)[:, :, ::-1]
|
||||
image = Image.fromarray(np.uint8(image))
|
||||
|
||||
prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
condition_image=image,
|
||||
num_inference_steps=50,
|
||||
joint_attention_kwargs={"scale": 0.9},
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
gen_images.save("output.png")
|
||||
```
|
||||
|
||||
## Full fine-tuning
|
||||
|
||||
We provide a non-LoRA version of the training script `train_control_flux.py`. Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
|
||||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
||||
--dataset_name="raulc0399/open_pose_controlnet" \
|
||||
--output_dir="pose-control" \
|
||||
--mixed_precision="bf16" \
|
||||
--train_batch_size=2 \
|
||||
--dataloader_num_workers=4 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--use_8bit_adam \
|
||||
--proportion_empty_prompts=0.2 \
|
||||
--learning_rate=5e-5 \
|
||||
--adam_weight_decay=1e-4 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="cosine" \
|
||||
--lr_warmup_steps=1000 \
|
||||
--checkpointing_steps=1000 \
|
||||
--max_train_steps=10000 \
|
||||
--validation_steps=200 \
|
||||
--validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
|
||||
--validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
|
||||
--offload \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Change the `validation_image` and `validation_prompt` as needed.
|
||||
|
||||
For inference, this time, we will run:
|
||||
|
||||
```py
|
||||
from controlnet_aux import OpenposeDetector
|
||||
from diffusers import FluxControlPipeline, FluxTransformer2DModel
|
||||
from diffusers.utils import load_image
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
transformer = FluxTransformer2DModel.from_pretrained("...") # change this.
|
||||
pipe = FluxControlPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")
|
||||
|
||||
# prepare pose condition.
|
||||
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
|
||||
image = load_image(url)
|
||||
image = open_pose(image, detect_resolution=512, image_resolution=1024)
|
||||
image = np.array(image)[:, :, ::-1]
|
||||
image = Image.fromarray(np.uint8(image))
|
||||
|
||||
prompt = "A couple, 4k photo, highly detailed"
|
||||
|
||||
gen_images = pipe(
|
||||
prompt=prompt,
|
||||
condition_image=image,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=25.,
|
||||
).images[0]
|
||||
gen_images.save("output.png")
|
||||
```
|
||||
|
||||
## Things to note
|
||||
|
||||
* The scripts provided in this directory are experimental and educational. This means we may have to tweak things around to get good results on a given condition. We believe this is best done with the community 🤗
|
||||
* The scripts are not memory-optimized but we offload the VAE and the text encoders to CPU when they are not used.
|
||||
* We can extract LoRAs from the fully fine-tuned model. While we currently don't provide any utilities for that, users are welcome to refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py) that provides a similar functionality.
|
||||
@@ -0,0 +1,6 @@
|
||||
transformers==4.47.0
|
||||
wandb
|
||||
torch
|
||||
torchvision
|
||||
accelerate==1.2.0
|
||||
peft>=0.14.0
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,175 @@
|
||||
# Search models on Civitai and Hugging Face
|
||||
|
||||
The [auto_diffusers](https://github.com/suzukimain/auto_diffusers) library provides additional functionalities to Diffusers such as searching for models on Civitai and the Hugging Face Hub.
|
||||
Please refer to the original library [here](https://pypi.org/project/auto-diffusers/)
|
||||
|
||||
## Installation
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
> [!IMPORTANT]
|
||||
> To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install .
|
||||
```
|
||||
Set up the pipeline. You can also cd to this folder and run it.
|
||||
```bash
|
||||
!wget https://raw.githubusercontent.com/suzukimain/auto_diffusers/refs/heads/master/src/auto_diffusers/pipeline_easy.py
|
||||
```
|
||||
|
||||
## Load from Civitai
|
||||
```python
|
||||
from pipeline_easy import (
|
||||
EasyPipelineForText2Image,
|
||||
EasyPipelineForImage2Image,
|
||||
EasyPipelineForInpainting,
|
||||
)
|
||||
|
||||
# Text-to-Image
|
||||
pipeline = EasyPipelineForText2Image.from_civitai(
|
||||
"search_word",
|
||||
base_model="SD 1.5",
|
||||
).to("cuda")
|
||||
|
||||
|
||||
# Image-to-Image
|
||||
pipeline = EasyPipelineForImage2Image.from_civitai(
|
||||
"search_word",
|
||||
base_model="SD 1.5",
|
||||
).to("cuda")
|
||||
|
||||
|
||||
# Inpainting
|
||||
pipeline = EasyPipelineForInpainting.from_civitai(
|
||||
"search_word",
|
||||
base_model="SD 1.5",
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
## Load from Hugging Face
|
||||
```python
|
||||
from pipeline_easy import (
|
||||
EasyPipelineForText2Image,
|
||||
EasyPipelineForImage2Image,
|
||||
EasyPipelineForInpainting,
|
||||
)
|
||||
|
||||
# Text-to-Image
|
||||
pipeline = EasyPipelineForText2Image.from_huggingface(
|
||||
"search_word",
|
||||
checkpoint_format="diffusers",
|
||||
).to("cuda")
|
||||
|
||||
|
||||
# Image-to-Image
|
||||
pipeline = EasyPipelineForImage2Image.from_huggingface(
|
||||
"search_word",
|
||||
checkpoint_format="diffusers",
|
||||
).to("cuda")
|
||||
|
||||
|
||||
# Inpainting
|
||||
pipeline = EasyPipelineForInpainting.from_huggingface(
|
||||
"search_word",
|
||||
checkpoint_format="diffusers",
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
|
||||
## Search Civitai and Huggingface
|
||||
|
||||
```python
|
||||
from pipeline_easy import (
|
||||
search_huggingface,
|
||||
search_civitai,
|
||||
)
|
||||
|
||||
# Search Lora
|
||||
Lora = search_civitai(
|
||||
"Keyword_to_search_Lora",
|
||||
model_type="LORA",
|
||||
base_model = "SD 1.5",
|
||||
download=True,
|
||||
)
|
||||
# Load Lora into the pipeline.
|
||||
pipeline.load_lora_weights(Lora)
|
||||
|
||||
|
||||
# Search TextualInversion
|
||||
TextualInversion = search_civitai(
|
||||
"EasyNegative",
|
||||
model_type="TextualInversion",
|
||||
base_model = "SD 1.5",
|
||||
download=True
|
||||
)
|
||||
# Load TextualInversion into the pipeline.
|
||||
pipeline.load_textual_inversion(TextualInversion, token="EasyNegative")
|
||||
```
|
||||
|
||||
### Search Civitai
|
||||
|
||||
> [!TIP]
|
||||
> **If an error occurs, insert the `token` and run again.**
|
||||
|
||||
#### `EasyPipeline.from_civitai` parameters
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
|:---------------:|:----------------------:|:-------------:|:-----------------------------------------------------------------------------------:|
|
||||
| search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. |
|
||||
| model_type | string | `Checkpoint` | The type of model to search for. <br>(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) |
|
||||
| base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) |
|
||||
| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. |
|
||||
| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |
|
||||
| cache_dir | string, Path | None | Path to the folder where cached files are stored. |
|
||||
| resume | bool | False | Whether to resume an incomplete download. |
|
||||
| token | string | None | API token for Civitai authentication. |
|
||||
|
||||
|
||||
#### `search_civitai` parameters
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
|:---------------:|:--------------:|:-------------:|:-----------------------------------------------------------------------------------:|
|
||||
| search_word | string, Path | ー | The search query string. Can be a keyword, Civitai URL, local directory or file path. |
|
||||
| model_type | string | `Checkpoint` | The type of model to search for. <br>(for example `Checkpoint`, `TextualInversion`, `Controlnet`, `LORA`, `Hypernetwork`, `AestheticGradient`, `Poses`) |
|
||||
| base_model | string | None | Trained model tag (for example `SD 1.5`, `SD 3.5`, `SDXL 1.0`) |
|
||||
| download | bool | False | Whether to download the model. |
|
||||
| force_download | bool | False | Whether to force the download if the model already exists. |
|
||||
| cache_dir | string, Path | None | Path to the folder where cached files are stored. |
|
||||
| resume | bool | False | Whether to resume an incomplete download. |
|
||||
| token | string | None | API token for Civitai authentication. |
|
||||
| include_params | bool | False | Whether to include parameters in the returned data. |
|
||||
| skip_error | bool | False | Whether to skip errors and return None. |
|
||||
|
||||
### Search Huggingface
|
||||
|
||||
> [!TIP]
|
||||
> **If an error occurs, insert the `token` and run again.**
|
||||
|
||||
#### `EasyPipeline.from_huggingface` parameters
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:|
|
||||
| search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`<creator>/<repo>`). |
|
||||
| checkpoint_format | string | `single_file` | The format of the model checkpoint.<br>● `single_file` to search for `single file checkpoint` <br>●`diffusers` to search for `multifolder diffusers format checkpoint` |
|
||||
| torch_dtype | string, torch.dtype | None | Override the default `torch.dtype` and load the model with another dtype. |
|
||||
| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |
|
||||
| cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. |
|
||||
| token | string, bool | None | The token to use as HTTP bearer authorization for remote files. |
|
||||
|
||||
|
||||
#### `search_huggingface` parameters
|
||||
|
||||
| Name | Type | Default | Description |
|
||||
|:---------------------:|:-------------------:|:--------------:|:----------------------------------------------------------------:|
|
||||
| search_word | string, Path | ー | The search query string. Can be a keyword, Hugging Face URL, local directory or file path, or a Hugging Face path (`<creator>/<repo>`). |
|
||||
| checkpoint_format | string | `single_file` | The format of the model checkpoint. <br>● `single_file` to search for `single file checkpoint` <br>●`diffusers` to search for `multifolder diffusers format checkpoint` |
|
||||
| pipeline_tag | string | None | Tag to filter models by pipeline. |
|
||||
| download | bool | False | Whether to download the model. |
|
||||
| force_download | bool | False | Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. |
|
||||
| cache_dir | string, Path | None | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache is not used. |
|
||||
| token | string, bool | None | The token to use as HTTP bearer authorization for remote files. |
|
||||
| include_params | bool | False | Whether to include parameters in the returned data. |
|
||||
| skip_error | bool | False | Whether to skip errors and return None. |
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1 @@
|
||||
huggingface-hub>=0.26.2
|
||||
@@ -7,13 +7,14 @@ It has been tested on v4 and v5p TPU versions. Training code has been tested on
|
||||
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
|
||||
where we shard the input batches over the TPU devices.
|
||||
|
||||
As of 9-11-2024, these are some expected step times.
|
||||
As of 10-31-2024, these are some expected step times.
|
||||
|
||||
| accelerator | global batch size | step time (seconds) |
|
||||
| ----------- | ----------------- | --------- |
|
||||
| v5p-128 | 1024 | 0.245 |
|
||||
| v5p-256 | 2048 | 0.234 |
|
||||
| v5p-512 | 4096 | 0.2498 |
|
||||
| v5p-512 | 16384 | 1.01 |
|
||||
| v5p-256 | 8192 | 1.01 |
|
||||
| v5p-128 | 4096 | 1.0 |
|
||||
| v5p-64 | 2048 | 1.01 |
|
||||
|
||||
## Create TPU
|
||||
|
||||
@@ -43,8 +44,9 @@ Install PyTorch and PyTorch/XLA nightly versions:
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
|
||||
--command='
|
||||
pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
pip3 install --pre torch==2.6.0.dev20241031+cpu torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241031.cxx11-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
'
|
||||
```
|
||||
|
||||
@@ -88,17 +90,18 @@ are fixed.
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
|
||||
--command='
|
||||
export XLA_DISABLE_FUNCTIONALIZATION=1
|
||||
export XLA_DISABLE_FUNCTIONALIZATION=0
|
||||
export PROFILE_DIR=/tmp/
|
||||
export CACHE_DIR=/tmp/
|
||||
export DATASET_NAME=lambdalabs/naruto-blip-captions
|
||||
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
|
||||
export TRAIN_STEPS=50
|
||||
export OUTPUT_DIR=/tmp/trained-model/
|
||||
python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'
|
||||
|
||||
python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=8 --loader_prefetch_size=4 --device_prefetch_size=4'
|
||||
```
|
||||
|
||||
Pass `--print_loss` if you would like to see the loss printed at every step. Be aware that printing the loss at every step disrupts the optimized flow execution, thus the step time will be longer.
|
||||
|
||||
### Environment Envs Explained
|
||||
|
||||
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
|
||||
|
||||
@@ -140,33 +140,43 @@ class TrainSD:
|
||||
self.optimizer.step()
|
||||
|
||||
def start_training(self):
|
||||
times = []
|
||||
last_time = time.time()
|
||||
step = 0
|
||||
while True:
|
||||
if self.global_step >= self.args.max_train_steps:
|
||||
xm.mark_step()
|
||||
break
|
||||
if step == 4 and PROFILE_DIR is not None:
|
||||
xm.wait_device_ops()
|
||||
xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
|
||||
dataloader_exception = False
|
||||
measure_start_step = args.measure_start_step
|
||||
assert measure_start_step < self.args.max_train_steps
|
||||
total_time = 0
|
||||
for step in range(0, self.args.max_train_steps):
|
||||
try:
|
||||
batch = next(self.dataloader)
|
||||
except Exception as e:
|
||||
dataloader_exception = True
|
||||
print(e)
|
||||
break
|
||||
if step == measure_start_step and PROFILE_DIR is not None:
|
||||
xm.wait_device_ops()
|
||||
xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
|
||||
last_time = time.time()
|
||||
loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
|
||||
step_time = time.time() - last_time
|
||||
if step >= 10:
|
||||
times.append(step_time)
|
||||
print(f"step: {step}, step_time: {step_time}")
|
||||
if step % 5 == 0:
|
||||
print(f"step: {step}, loss: {loss}")
|
||||
last_time = time.time()
|
||||
self.global_step += 1
|
||||
step += 1
|
||||
# print(f"Average step time: {sum(times)/len(times)}")
|
||||
xm.wait_device_ops()
|
||||
|
||||
def print_loss_closure(step, loss):
|
||||
print(f"Step: {step}, Loss: {loss}")
|
||||
|
||||
if args.print_loss:
|
||||
xm.add_step_closure(
|
||||
print_loss_closure,
|
||||
args=(
|
||||
self.global_step,
|
||||
loss,
|
||||
),
|
||||
)
|
||||
xm.mark_step()
|
||||
if not dataloader_exception:
|
||||
xm.wait_device_ops()
|
||||
total_time = time.time() - last_time
|
||||
print(f"Average step time: {total_time/(self.args.max_train_steps-measure_start_step)}")
|
||||
else:
|
||||
print("dataloader exception happen, skip result")
|
||||
return
|
||||
|
||||
def step_fn(
|
||||
self,
|
||||
@@ -180,7 +190,10 @@ class TrainSD:
|
||||
noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
|
||||
bsz = latents.shape[0]
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
||||
0,
|
||||
self.noise_scheduler.config.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
@@ -224,9 +237,6 @@ class TrainSD:
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
|
||||
)
|
||||
parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
@@ -258,12 +268,6 @@ def parse_args():
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
@@ -283,15 +287,6 @@ def parse_args():
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
@@ -304,7 +299,6 @@ def parse_args():
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
@@ -374,12 +368,19 @@ def parse_args():
|
||||
default=1,
|
||||
help=("Number of subprocesses to use for data loading to cpu."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loader_prefetch_factor",
|
||||
type=int,
|
||||
default=2,
|
||||
help=("Number of batches loaded in advance by each worker."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device_prefetch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help=("Number of subprocesses to use for data loading to tpu from cpu. "),
|
||||
)
|
||||
parser.add_argument("--measure_start_step", type=int, default=10, help="Step to start profiling.")
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
@@ -394,12 +395,8 @@ def parse_args():
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
choices=["no", "bf16"],
|
||||
help=("Whether to use mixed precision. Bf16 requires PyTorch >= 1.10"),
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
@@ -409,6 +406,12 @@ def parse_args():
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--print_loss",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=("Print loss at every step."),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -436,7 +439,6 @@ def load_dataset(args):
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = datasets.load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
data_dir=args.train_data_dir,
|
||||
)
|
||||
@@ -481,9 +483,7 @@ def main(args):
|
||||
_ = xp.start_server(PORT)
|
||||
|
||||
num_devices = xr.global_runtime_device_count()
|
||||
device_ids = np.arange(num_devices)
|
||||
mesh_shape = (num_devices, 1)
|
||||
mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
|
||||
mesh = xs.get_1d_mesh("data")
|
||||
xs.set_global_mesh(mesh)
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
@@ -520,6 +520,7 @@ def main(args):
|
||||
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
|
||||
|
||||
unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
|
||||
unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
@@ -530,15 +531,12 @@ def main(args):
|
||||
# as these weights are only used for inference, keeping weights in full
|
||||
# precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
if args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
device = xm.xla_device()
|
||||
print("device: ", device)
|
||||
print("weight_dtype: ", weight_dtype)
|
||||
|
||||
# Move text_encode and vae to device and cast to weight_dtype
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
vae = vae.to(device, dtype=weight_dtype)
|
||||
unet = unet.to(device, dtype=weight_dtype)
|
||||
@@ -606,24 +604,27 @@ def main(args):
|
||||
collate_fn=collate_fn,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
batch_size=args.train_batch_size,
|
||||
prefetch_factor=args.loader_prefetch_factor,
|
||||
)
|
||||
|
||||
train_dataloader = pl.MpDeviceLoader(
|
||||
train_dataloader,
|
||||
device,
|
||||
input_sharding={
|
||||
"pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
|
||||
"input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
|
||||
"pixel_values": xs.ShardingSpec(mesh, ("data", None, None, None), minibatch=True),
|
||||
"input_ids": xs.ShardingSpec(mesh, ("data", None), minibatch=True),
|
||||
},
|
||||
loader_prefetch_size=args.loader_prefetch_size,
|
||||
device_prefetch_size=args.device_prefetch_size,
|
||||
)
|
||||
|
||||
num_hosts = xr.process_count()
|
||||
num_devices_per_host = num_devices // num_hosts
|
||||
if xm.is_master_ordinal():
|
||||
print("***** Running training *****")
|
||||
print(f"Instantaneous batch size per device = {args.train_batch_size}")
|
||||
print(f"Instantaneous batch size per device = {args.train_batch_size // num_devices_per_host }")
|
||||
print(
|
||||
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
|
||||
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_hosts}"
|
||||
)
|
||||
print(f" Total optimization steps = {args.max_train_steps}")
|
||||
|
||||
|
||||
@@ -0,0 +1,323 @@
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from diffusers import AutoencoderDC
|
||||
|
||||
|
||||
def remap_qkv_(key: str, state_dict: Dict[str, Any]):
|
||||
qkv = state_dict.pop(key)
|
||||
q, k, v = torch.chunk(qkv, 3, dim=0)
|
||||
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
|
||||
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
|
||||
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
|
||||
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
|
||||
|
||||
|
||||
def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
|
||||
parent_module, _, _ = key.rpartition(".proj.conv.weight")
|
||||
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
|
||||
|
||||
|
||||
AE_KEYS_RENAME_DICT = {
|
||||
# common
|
||||
"main.": "",
|
||||
"op_list.": "",
|
||||
"context_module": "attn",
|
||||
"local_module": "conv_out",
|
||||
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
|
||||
# If there were more scales, there would be more layers, so a loop would be better to handle this
|
||||
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
|
||||
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
|
||||
"depth_conv.conv": "conv_depth",
|
||||
"inverted_conv.conv": "conv_inverted",
|
||||
"point_conv.conv": "conv_point",
|
||||
"point_conv.norm": "norm",
|
||||
"conv.conv.": "conv.",
|
||||
"conv1.conv": "conv1",
|
||||
"conv2.conv": "conv2",
|
||||
"conv2.norm": "norm",
|
||||
"proj.norm": "norm_out",
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in",
|
||||
"encoder.project_out.0.conv": "encoder.conv_out",
|
||||
"encoder.stages": "encoder.down_blocks",
|
||||
# decoder
|
||||
"decoder.project_in.conv": "decoder.conv_in",
|
||||
"decoder.project_out.0": "decoder.norm_out",
|
||||
"decoder.project_out.2.conv": "decoder.conv_out",
|
||||
"decoder.stages": "decoder.up_blocks",
|
||||
}
|
||||
|
||||
AE_F32C32_KEYS = {
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in.conv",
|
||||
# decoder
|
||||
"decoder.project_out.2.conv": "decoder.conv_out.conv",
|
||||
}
|
||||
|
||||
AE_F64C128_KEYS = {
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in.conv",
|
||||
# decoder
|
||||
"decoder.project_out.2.conv": "decoder.conv_out.conv",
|
||||
}
|
||||
|
||||
AE_F128C512_KEYS = {
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in.conv",
|
||||
# decoder
|
||||
"decoder.project_out.2.conv": "decoder.conv_out.conv",
|
||||
}
|
||||
|
||||
AE_SPECIAL_KEYS_REMAP = {
|
||||
"qkv.conv.weight": remap_qkv_,
|
||||
"proj.conv.weight": remap_proj_conv_,
|
||||
}
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
if "model" in saved_dict.keys():
|
||||
state_dict = state_dict["model"]
|
||||
if "module" in saved_dict.keys():
|
||||
state_dict = state_dict["module"]
|
||||
if "state_dict" in saved_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
return state_dict
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def convert_ae(config_name: str, dtype: torch.dtype):
|
||||
config = get_ae_config(config_name)
|
||||
hub_id = f"mit-han-lab/{config_name}"
|
||||
ckpt_path = hf_hub_download(hub_id, "model.safetensors")
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
|
||||
ae = AutoencoderDC(**config).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
ae.load_state_dict(original_state_dict, strict=True)
|
||||
return ae
|
||||
|
||||
|
||||
def get_ae_config(name: str):
|
||||
if name in ["dc-ae-f32c32-sana-1.0"]:
|
||||
config = {
|
||||
"latent_channels": 32,
|
||||
"encoder_block_types": (
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
"decoder_block_types": (
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
"encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
|
||||
"decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
|
||||
"encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
|
||||
"decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
|
||||
"encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
|
||||
"decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
|
||||
"downsample_block_type": "conv",
|
||||
"upsample_block_type": "interpolate",
|
||||
"decoder_norm_types": "rms_norm",
|
||||
"decoder_act_fns": "silu",
|
||||
"scaling_factor": 0.41407,
|
||||
}
|
||||
elif name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
|
||||
AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
|
||||
config = {
|
||||
"latent_channels": 32,
|
||||
"encoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"decoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
|
||||
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
|
||||
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
|
||||
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
|
||||
"encoder_qkv_multiscales": ((), (), (), (), (), ()),
|
||||
"decoder_qkv_multiscales": ((), (), (), (), (), ()),
|
||||
"decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
|
||||
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
|
||||
}
|
||||
if name == "dc-ae-f32c32-in-1.0":
|
||||
config["scaling_factor"] = 0.3189
|
||||
elif name == "dc-ae-f32c32-mix-1.0":
|
||||
config["scaling_factor"] = 0.4552
|
||||
elif name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
|
||||
AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
|
||||
config = {
|
||||
"latent_channels": 128,
|
||||
"encoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"decoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
|
||||
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
|
||||
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
|
||||
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
|
||||
"encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
|
||||
"decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
|
||||
"decoder_norm_types": [
|
||||
"batch_norm",
|
||||
"batch_norm",
|
||||
"batch_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
],
|
||||
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
|
||||
}
|
||||
if name == "dc-ae-f64c128-in-1.0":
|
||||
config["scaling_factor"] = 0.2889
|
||||
elif name == "dc-ae-f64c128-mix-1.0":
|
||||
config["scaling_factor"] = 0.4538
|
||||
elif name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
|
||||
AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
|
||||
config = {
|
||||
"latent_channels": 512,
|
||||
"encoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"decoder_block_types": [
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
"EfficientViTBlock",
|
||||
],
|
||||
"encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
|
||||
"decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
|
||||
"encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
|
||||
"decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
|
||||
"encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
|
||||
"decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
|
||||
"decoder_norm_types": [
|
||||
"batch_norm",
|
||||
"batch_norm",
|
||||
"batch_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
"rms_norm",
|
||||
],
|
||||
"decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
|
||||
}
|
||||
if name == "dc-ae-f128c512-in-1.0":
|
||||
config["scaling_factor"] = 0.4883
|
||||
elif name == "dc-ae-f128c512-mix-1.0":
|
||||
config["scaling_factor"] = 0.3620
|
||||
else:
|
||||
raise ValueError("Invalid config name provided.")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default="dc-ae-f32c32-sana-1.0",
|
||||
choices=[
|
||||
"dc-ae-f32c32-sana-1.0",
|
||||
"dc-ae-f32c32-in-1.0",
|
||||
"dc-ae-f32c32-mix-1.0",
|
||||
"dc-ae-f64c128-in-1.0",
|
||||
"dc-ae-f64c128-mix-1.0",
|
||||
"dc-ae-f128c512-in-1.0",
|
||||
"dc-ae-f128c512-mix-1.0",
|
||||
],
|
||||
help="The DCAE checkpoint to convert",
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
|
||||
ae = convert_ae(args.config_name, dtype)
|
||||
ae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
@@ -0,0 +1,97 @@
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import safetensors.torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
|
||||
vision = True
|
||||
else:
|
||||
vision = False
|
||||
|
||||
"""
|
||||
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
|
||||
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
|
||||
--filename "flux-ip-adapter.safetensors"
|
||||
--output_path "flux-ip-adapter-hf/"
|
||||
"""
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
||||
parser.add_argument("--filename", default="flux.safetensors", type=str)
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--output_path", type=str)
|
||||
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_original_checkpoint(args):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError(" please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
|
||||
converted_state_dict = {}
|
||||
|
||||
# image_proj
|
||||
## norm
|
||||
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
|
||||
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
|
||||
## proj
|
||||
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
|
||||
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"ip_adapter.{i}."
|
||||
# to_k_ip
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
|
||||
)
|
||||
# to_v_ip
|
||||
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
|
||||
)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def main(args):
|
||||
original_ckpt = load_original_checkpoint(args)
|
||||
|
||||
num_layers = 19
|
||||
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
|
||||
|
||||
print("Saving Flux IP-Adapter in Diffusers format.")
|
||||
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
|
||||
|
||||
if vision:
|
||||
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
|
||||
model.save_pretrained(f"{args.output_path}/image_encoder")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -0,0 +1,257 @@
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLHunyuanVideo,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HunyuanVideoPipeline,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
)
|
||||
|
||||
|
||||
def remap_norm_scale_shift_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
state_dict[key.replace("final_layer.adaLN_modulation.1", "norm_out.linear")] = new_weight
|
||||
|
||||
|
||||
def remap_txt_in_(key, state_dict):
|
||||
def rename_key(key):
|
||||
new_key = key.replace("individual_token_refiner.blocks", "token_refiner.refiner_blocks")
|
||||
new_key = new_key.replace("adaLN_modulation.1", "norm_out.linear")
|
||||
new_key = new_key.replace("txt_in", "context_embedder")
|
||||
new_key = new_key.replace("t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1")
|
||||
new_key = new_key.replace("t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2")
|
||||
new_key = new_key.replace("c_embedder", "time_text_embed.text_embedder")
|
||||
new_key = new_key.replace("mlp", "ff")
|
||||
return new_key
|
||||
|
||||
if "self_attn_qkv" in key:
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_q"))] = to_q
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_k"))] = to_k
|
||||
state_dict[rename_key(key.replace("self_attn_qkv", "attn.to_v"))] = to_v
|
||||
else:
|
||||
state_dict[rename_key(key)] = state_dict.pop(key)
|
||||
|
||||
|
||||
def remap_img_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_q")] = to_q
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_k")] = to_k
|
||||
state_dict[key.replace("img_attn_qkv", "attn.to_v")] = to_v
|
||||
|
||||
|
||||
def remap_txt_attn_qkv_(key, state_dict):
|
||||
weight = state_dict.pop(key)
|
||||
to_q, to_k, to_v = weight.chunk(3, dim=0)
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_q_proj")] = to_q
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_k_proj")] = to_k
|
||||
state_dict[key.replace("txt_attn_qkv", "attn.add_v_proj")] = to_v
|
||||
|
||||
|
||||
def remap_single_transformer_blocks_(key, state_dict):
|
||||
hidden_size = 3072
|
||||
|
||||
if "linear1.weight" in key:
|
||||
linear1_weight = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_weight.size(0) - 3 * hidden_size)
|
||||
q, k, v, mlp = torch.split(linear1_weight, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.weight")
|
||||
state_dict[f"{new_key}.attn.to_q.weight"] = q
|
||||
state_dict[f"{new_key}.attn.to_k.weight"] = k
|
||||
state_dict[f"{new_key}.attn.to_v.weight"] = v
|
||||
state_dict[f"{new_key}.proj_mlp.weight"] = mlp
|
||||
|
||||
elif "linear1.bias" in key:
|
||||
linear1_bias = state_dict.pop(key)
|
||||
split_size = (hidden_size, hidden_size, hidden_size, linear1_bias.size(0) - 3 * hidden_size)
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(linear1_bias, split_size, dim=0)
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks").removesuffix(".linear1.bias")
|
||||
state_dict[f"{new_key}.attn.to_q.bias"] = q_bias
|
||||
state_dict[f"{new_key}.attn.to_k.bias"] = k_bias
|
||||
state_dict[f"{new_key}.attn.to_v.bias"] = v_bias
|
||||
state_dict[f"{new_key}.proj_mlp.bias"] = mlp_bias
|
||||
|
||||
else:
|
||||
new_key = key.replace("single_blocks", "single_transformer_blocks")
|
||||
new_key = new_key.replace("linear2", "proj_out")
|
||||
new_key = new_key.replace("q_norm", "attn.norm_q")
|
||||
new_key = new_key.replace("k_norm", "attn.norm_k")
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"img_in": "x_embedder",
|
||||
"time_in.mlp.0": "time_text_embed.timestep_embedder.linear_1",
|
||||
"time_in.mlp.2": "time_text_embed.timestep_embedder.linear_2",
|
||||
"guidance_in.mlp.0": "time_text_embed.guidance_embedder.linear_1",
|
||||
"guidance_in.mlp.2": "time_text_embed.guidance_embedder.linear_2",
|
||||
"vector_in.in_layer": "time_text_embed.text_embedder.linear_1",
|
||||
"vector_in.out_layer": "time_text_embed.text_embedder.linear_2",
|
||||
"double_blocks": "transformer_blocks",
|
||||
"img_attn_q_norm": "attn.norm_q",
|
||||
"img_attn_k_norm": "attn.norm_k",
|
||||
"img_attn_proj": "attn.to_out.0",
|
||||
"txt_attn_q_norm": "attn.norm_added_q",
|
||||
"txt_attn_k_norm": "attn.norm_added_k",
|
||||
"txt_attn_proj": "attn.to_add_out",
|
||||
"img_mod.linear": "norm1.linear",
|
||||
"img_norm1": "norm1.norm",
|
||||
"img_norm2": "norm2",
|
||||
"img_mlp": "ff",
|
||||
"txt_mod.linear": "norm1_context.linear",
|
||||
"txt_norm1": "norm1.norm",
|
||||
"txt_norm2": "norm2_context",
|
||||
"txt_mlp": "ff_context",
|
||||
"self_attn_proj": "attn.to_out.0",
|
||||
"modulation.linear": "norm.linear",
|
||||
"pre_norm": "norm.norm",
|
||||
"final_layer.norm_final": "norm_out.norm",
|
||||
"final_layer.linear": "proj_out",
|
||||
"fc1": "net.0.proj",
|
||||
"fc2": "net.2",
|
||||
"input_embedder": "proj_in",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"txt_in": remap_txt_in_,
|
||||
"img_attn_qkv": remap_img_attn_qkv_,
|
||||
"txt_attn_qkv": remap_txt_attn_qkv_,
|
||||
"single_blocks": remap_single_transformer_blocks_,
|
||||
"final_layer.adaLN_modulation.1": remap_norm_scale_shift_,
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
if "model" in saved_dict.keys():
|
||||
state_dict = state_dict["model"]
|
||||
if "module" in saved_dict.keys():
|
||||
state_dict = state_dict["module"]
|
||||
if "state_dict" in saved_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
return state_dict
|
||||
|
||||
|
||||
def convert_transformer(ckpt_path: str):
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = HunyuanVideoTransformer3DModel()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str):
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLHunyuanVideo()
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
|
||||
parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original llama checkpoint")
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original llama tokenizer")
|
||||
parser.add_argument("--text_encoder_2_path", type=str, default=None, help="Path to original clip checkpoint")
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
assert args.text_encoder_2_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer = convert_transformer(args.transformer_ckpt_path)
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae = convert_vae(args.vae_ckpt_path)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder = AutoModel.from_pretrained(args.text_encoder_path, torch_dtype=torch.float16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path, padding_side="right")
|
||||
text_encoder_2 = CLIPTextModel.from_pretrained(args.text_encoder_2_path, torch_dtype=torch.float16)
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(args.text_encoder_2_path)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
|
||||
|
||||
pipe = HunyuanVideoPipeline(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
@@ -0,0 +1,209 @@
|
||||
import argparse
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
|
||||
from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LTXPipeline, LTXVideoTransformer3DModel
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
state_dict.pop(key)
|
||||
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 128
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"patchify_proj": "proj_in",
|
||||
"adaln_single": "time_embed",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0",
|
||||
"up_blocks.2": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.3": "up_blocks.1",
|
||||
"up_blocks.4": "up_blocks.2.conv_in",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.conv_in",
|
||||
"up_blocks.8": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.9": "up_blocks.3",
|
||||
# encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.0.conv_out",
|
||||
"down_blocks.3": "down_blocks.1",
|
||||
"down_blocks.4": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.5": "down_blocks.1.conv_out",
|
||||
"down_blocks.6": "down_blocks.2",
|
||||
"down_blocks.7": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.8": "down_blocks.3",
|
||||
"down_blocks.9": "mid_block",
|
||||
# common
|
||||
"conv_shortcut": "conv_shortcut.conv",
|
||||
"res_blocks": "resnets",
|
||||
"norm3.norm": "norm3",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
}
|
||||
|
||||
|
||||
def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
|
||||
state_dict = saved_dict
|
||||
if "model" in saved_dict.keys():
|
||||
state_dict = state_dict["model"]
|
||||
if "module" in saved_dict.keys():
|
||||
state_dict = state_dict["module"]
|
||||
if "state_dict" in saved_dict.keys():
|
||||
state_dict = state_dict["state_dict"]
|
||||
return state_dict
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def convert_transformer(
|
||||
ckpt_path: str,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
PREFIX_KEY = ""
|
||||
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
transformer = LTXVideoTransformer3DModel().to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[len(PREFIX_KEY) :]
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str, dtype: torch.dtype):
|
||||
original_state_dict = get_state_dict(load_file(ckpt_path))
|
||||
vae = AutoencoderKLLTXVideo().to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original vae checkpoint")
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--typecast_text_encoder",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether or not to apply fp16/bf16 precision to text_encoder",
|
||||
)
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="fp32", help="Torch dtype to save the model in.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
transformer = None
|
||||
dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
transformer: LTXVideoTransformer3DModel = convert_transformer(args.transformer_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
transformer.save_pretrained(
|
||||
args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae: AutoencoderKLLTXVideo = convert_vae(args.vae_ckpt_path, dtype)
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
if args.save_pipeline:
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
|
||||
if args.typecast_text_encoder:
|
||||
text_encoder = text_encoder.to(dtype=dtype)
|
||||
|
||||
# Apparently, the conversion does not work anymore without this :shrug:
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTXPipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, variant=variant, max_shard_size="5GB")
|
||||
@@ -0,0 +1,308 @@
|
||||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from termcolor import colored
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderDC,
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SanaPipeline,
|
||||
SanaTransformer2DModel,
|
||||
)
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth",
|
||||
"Efficient-Large-Model/Sana_1600M_512px/checkpoints/Sana_1600M_512px.pth",
|
||||
"Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth",
|
||||
"Efficient-Large-Model/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth",
|
||||
]
|
||||
# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py
|
||||
|
||||
|
||||
def main(args):
|
||||
cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
|
||||
|
||||
if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
|
||||
ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
|
||||
snapshot_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
file_path = hf_hub_download(
|
||||
repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
|
||||
filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
|
||||
cache_dir=cache_dir_path,
|
||||
repo_type="model",
|
||||
)
|
||||
else:
|
||||
file_path = args.orig_ckpt_path
|
||||
|
||||
print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
|
||||
all_state_dict = torch.load(file_path, weights_only=True)
|
||||
state_dict = all_state_dict.pop("state_dict")
|
||||
converted_state_dict = {}
|
||||
|
||||
# Patch embeddings.
|
||||
converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Caption projection.
|
||||
converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
|
||||
converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
|
||||
converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
|
||||
converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
|
||||
|
||||
# AdaLN-single LN
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Shared norm.
|
||||
converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
|
||||
converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
|
||||
|
||||
# y norm
|
||||
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
|
||||
|
||||
flow_shift = 3.0
|
||||
if args.model_type == "SanaMS_1600M_P1_D20":
|
||||
layer_num = 20
|
||||
elif args.model_type == "SanaMS_600M_P1_D28":
|
||||
layer_num = 28
|
||||
else:
|
||||
raise ValueError(f"{args.model_type} is not supported.")
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
f"blocks.{depth}.scale_shift_table"
|
||||
)
|
||||
|
||||
# Linear Attention is all you need 🤘
|
||||
# Self attention.
|
||||
q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
|
||||
# Projection.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.attn.proj.bias"
|
||||
)
|
||||
|
||||
# Feed-forward.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.inverted_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.depth_conv.conv.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.mlp.point_conv.conv.weight"
|
||||
)
|
||||
|
||||
# Cross-attention.
|
||||
q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
|
||||
q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
|
||||
k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
|
||||
f"blocks.{depth}.cross_attn.proj.bias"
|
||||
)
|
||||
|
||||
# Final block.
|
||||
converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
|
||||
converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
|
||||
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer = SanaTransformer2DModel(
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"],
|
||||
attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"],
|
||||
num_layers=model_kwargs[args.model_type]["num_layers"],
|
||||
num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"],
|
||||
cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"],
|
||||
cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"],
|
||||
caption_channels=2304,
|
||||
mlp_ratio=2.5,
|
||||
attention_bias=False,
|
||||
sample_size=args.image_size // 32,
|
||||
patch_size=1,
|
||||
norm_elementwise_affine=False,
|
||||
norm_eps=1e-6,
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(transformer, converted_state_dict)
|
||||
else:
|
||||
transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
|
||||
|
||||
try:
|
||||
state_dict.pop("y_embedder.y_embedding")
|
||||
state_dict.pop("pos_embed")
|
||||
except KeyError:
|
||||
print("y_embedder.y_embedding or pos_embed not found in the state_dict")
|
||||
|
||||
assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
print(f"Total number of transformer parameters: {num_model_params}")
|
||||
|
||||
transformer = transformer.to(weight_dtype)
|
||||
|
||||
if not args.save_full_pipeline:
|
||||
print(
|
||||
colored(
|
||||
f"Only saving transformer model of {args.model_type}. "
|
||||
f"Set --save_full_pipeline to save the whole SanaPipeline",
|
||||
"green",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
transformer.save_pretrained(
|
||||
os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
|
||||
)
|
||||
else:
|
||||
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "google/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
|
||||
tokenizer.padding_side = "right"
|
||||
text_encoder = AutoModelForCausalLM.from_pretrained(
|
||||
text_encoder_model_path, torch_dtype=torch.bfloat16
|
||||
).get_decoder()
|
||||
|
||||
# Scheduler
|
||||
if args.scheduler_type == "flow-dpm_solver":
|
||||
scheduler = DPMSolverMultistepScheduler(
|
||||
flow_shift=flow_shift,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
)
|
||||
elif args.scheduler_type == "flow-euler":
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
|
||||
else:
|
||||
raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
|
||||
|
||||
pipe = SanaPipeline(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
transformer=transformer,
|
||||
vae=ae,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant)
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
default=1024,
|
||||
type=int,
|
||||
choices=[512, 1024, 2048],
|
||||
required=False,
|
||||
help="Image size of pretrained model, 512, 1024 or 2048.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"]
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
|
||||
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
model_kwargs = {
|
||||
"SanaMS_1600M_P1_D20": {
|
||||
"num_attention_heads": 70,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 20,
|
||||
"cross_attention_head_dim": 112,
|
||||
"cross_attention_dim": 2240,
|
||||
"num_layers": 20,
|
||||
},
|
||||
"SanaMS_600M_P1_D28": {
|
||||
"num_attention_heads": 36,
|
||||
"attention_head_dim": 32,
|
||||
"num_cross_attention_heads": 16,
|
||||
"cross_attention_head_dim": 72,
|
||||
"cross_attention_dim": 1152,
|
||||
"num_layers": 28,
|
||||
},
|
||||
}
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
weight_dtype = DTYPE_MAPPING[args.dtype]
|
||||
variant = VARIANT_MAPPING[args.dtype]
|
||||
|
||||
main(args)
|
||||
@@ -31,7 +31,7 @@ _import_structure = {
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"pipelines": [],
|
||||
"quantizers.quantization_config": ["BitsAndBytesConfig"],
|
||||
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
|
||||
"schedulers": [],
|
||||
"utils": [
|
||||
"OptionalDependencyNotAvailable",
|
||||
@@ -80,9 +80,12 @@ else:
|
||||
"AllegroTransformer3DModel",
|
||||
"AsymmetricAutoencoderKL",
|
||||
"AuraFlowTransformer2DModel",
|
||||
"AutoencoderDC",
|
||||
"AutoencoderKL",
|
||||
"AutoencoderKLAllegro",
|
||||
"AutoencoderKLCogVideoX",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMochi",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderOobleck",
|
||||
@@ -91,6 +94,7 @@ else:
|
||||
"CogView3PlusTransformer2DModel",
|
||||
"ConsistencyDecoderVAE",
|
||||
"ControlNetModel",
|
||||
"ControlNetUnionModel",
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"FluxControlNetModel",
|
||||
@@ -99,9 +103,11 @@ else:
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"HunyuanVideoTransformer3DModel",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
"LatteTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
"MochiTransformer3DModel",
|
||||
"ModelMixin",
|
||||
@@ -110,6 +116,7 @@ else:
|
||||
"MultiControlNetModel",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"SanaTransformer2DModel",
|
||||
"SD3ControlNetModel",
|
||||
"SD3MultiControlNetModel",
|
||||
"SD3Transformer2DModel",
|
||||
@@ -270,6 +277,7 @@ else:
|
||||
"CogView3PlusPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
"FluxControlNetInpaintPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
@@ -282,6 +290,7 @@ else:
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"HunyuanVideoPipeline",
|
||||
"I2VGenXLPipeline",
|
||||
"IFImg2ImgPipeline",
|
||||
"IFImg2ImgSuperResolutionPipeline",
|
||||
@@ -315,6 +324,8 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXPipeline",
|
||||
"LuminaText2ImgPipeline",
|
||||
"MarigoldDepthPipeline",
|
||||
"MarigoldNormalsPipeline",
|
||||
@@ -326,6 +337,8 @@ else:
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"ReduxImageEncoder",
|
||||
"SanaPAGPipeline",
|
||||
"SanaPipeline",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -339,6 +352,7 @@ else:
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
"StableDiffusion3PAGImg2ImgPipeline",
|
||||
"StableDiffusion3PAGImg2ImgPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
@@ -362,6 +376,7 @@ else:
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionModelEditingPipeline",
|
||||
"StableDiffusionPAGImg2ImgPipeline",
|
||||
"StableDiffusionPAGInpaintPipeline",
|
||||
"StableDiffusionPAGPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionParadigmsPipeline",
|
||||
@@ -376,6 +391,9 @@ else:
|
||||
"StableDiffusionXLControlNetPAGImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetPAGPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
"StableDiffusionXLControlNetUnionImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetUnionInpaintPipeline",
|
||||
"StableDiffusionXLControlNetUnionPipeline",
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
@@ -552,7 +570,7 @@ else:
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig
|
||||
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
@@ -572,9 +590,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AllegroTransformer3DModel,
|
||||
AsymmetricAutoencoderKL,
|
||||
AuraFlowTransformer2DModel,
|
||||
AutoencoderDC,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderOobleck,
|
||||
@@ -583,6 +604,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusTransformer2DModel,
|
||||
ConsistencyDecoderVAE,
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
FluxControlNetModel,
|
||||
@@ -591,9 +613,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
LatteTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
ModelMixin,
|
||||
@@ -602,6 +626,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MultiControlNetModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SanaTransformer2DModel,
|
||||
SD3ControlNetModel,
|
||||
SD3MultiControlNetModel,
|
||||
SD3Transformer2DModel,
|
||||
@@ -741,6 +766,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView3PlusPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlInpaintPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
@@ -753,6 +779,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
HunyuanVideoPipeline,
|
||||
I2VGenXLPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
IFImg2ImgSuperResolutionPipeline,
|
||||
@@ -786,6 +813,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXPipeline,
|
||||
LuminaText2ImgPipeline,
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldNormalsPipeline,
|
||||
@@ -797,6 +826,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
ReduxImageEncoder,
|
||||
SanaPAGPipeline,
|
||||
SanaPipeline,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
@@ -832,6 +863,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPAGImg2ImgPipeline,
|
||||
StableDiffusionPAGInpaintPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
@@ -846,6 +878,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLControlNetPAGImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLControlNetUnionImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetUnionInpaintPipeline,
|
||||
StableDiffusionXLControlNetUnionPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
|
||||
@@ -347,7 +347,6 @@ class ConfigMixin:
|
||||
_ = kwargs.pop("mirror", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
dduf_reader = kwargs.pop("dduf_reader", None)
|
||||
|
||||
user_agent = {**user_agent, "file_type": "config"}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
@@ -359,22 +358,8 @@ class ConfigMixin:
|
||||
"`self.config_name` is not defined. Note that one should not load a config from "
|
||||
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
||||
)
|
||||
# Custom path for now
|
||||
if dduf_reader:
|
||||
if subfolder is not None:
|
||||
if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
|
||||
config_file = os.path.join(subfolder, cls.config_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
|
||||
)
|
||||
elif dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if subfolder is not None and os.path.isfile(
|
||||
@@ -441,8 +426,10 @@ class ConfigMixin:
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {cls.config_name} file"
|
||||
)
|
||||
|
||||
try:
|
||||
config_dict = cls._dict_from_json_file(config_file, dduf_reader=dduf_reader)
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
|
||||
commit_hash = extract_commit_hash(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
@@ -565,12 +552,9 @@ class ConfigMixin:
|
||||
return init_dict, unused_kwargs, hidden_config_dict
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike], dduf_reader=None):
|
||||
if dduf_reader:
|
||||
text = dduf_reader.read_file(json_file, encoding="utf-8")
|
||||
else:
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return json.loads(text)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -236,7 +236,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
`np.ndarray` or `torch.Tensor`:
|
||||
The denormalized image array.
|
||||
"""
|
||||
return (images / 2 + 0.5).clamp(0, 1)
|
||||
return (images * 0.5 + 0.5).clamp(0, 1)
|
||||
|
||||
@staticmethod
|
||||
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
||||
@@ -537,6 +537,26 @@ class VaeImageProcessor(ConfigMixin):
|
||||
|
||||
return image
|
||||
|
||||
def _denormalize_conditionally(
|
||||
self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Denormalize a batch of images based on a condition list.
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The input image tensor.
|
||||
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
|
||||
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
|
||||
value of `do_normalize` in the `VaeImageProcessor` config.
|
||||
"""
|
||||
if do_denormalize is None:
|
||||
return self.denormalize(images) if self.config.do_normalize else images
|
||||
|
||||
return torch.stack(
|
||||
[self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
|
||||
)
|
||||
|
||||
def get_default_height_width(
|
||||
self,
|
||||
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
||||
@@ -752,12 +772,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
if output_type == "latent":
|
||||
return image
|
||||
|
||||
if do_denormalize is None:
|
||||
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
||||
|
||||
image = torch.stack(
|
||||
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
||||
)
|
||||
image = self._denormalize_conditionally(image, do_denormalize)
|
||||
|
||||
if output_type == "pt":
|
||||
return image
|
||||
@@ -966,12 +981,7 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
|
||||
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
||||
output_type = "np"
|
||||
|
||||
if do_denormalize is None:
|
||||
do_denormalize = [self.config.do_normalize] * image.shape[0]
|
||||
|
||||
image = torch.stack(
|
||||
[self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])]
|
||||
)
|
||||
image = self._denormalize_conditionally(image, do_denormalize)
|
||||
|
||||
image = self.pt_to_numpy(image)
|
||||
|
||||
|
||||
@@ -55,7 +55,8 @@ _import_structure = {}
|
||||
|
||||
if is_torch_available():
|
||||
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
|
||||
|
||||
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
|
||||
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
|
||||
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
if is_transformers_available():
|
||||
@@ -65,13 +66,20 @@ if is_torch_available():
|
||||
"StableDiffusionLoraLoaderMixin",
|
||||
"SD3LoraLoaderMixin",
|
||||
"StableDiffusionXLLoraLoaderMixin",
|
||||
"LTXVideoLoraLoaderMixin",
|
||||
"LoraLoaderMixin",
|
||||
"FluxLoraLoaderMixin",
|
||||
"CogVideoXLoraLoaderMixin",
|
||||
"Mochi1LoraLoaderMixin",
|
||||
"HunyuanVideoLoraLoaderMixin",
|
||||
"SanaLoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
_import_structure["ip_adapter"] = [
|
||||
"IPAdapterMixin",
|
||||
"FluxIPAdapterMixin",
|
||||
"SD3IPAdapterMixin",
|
||||
]
|
||||
|
||||
_import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
|
||||
@@ -79,17 +87,26 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
if is_torch_available():
|
||||
from .single_file_model import FromOriginalModelMixin
|
||||
from .transformer_flux import FluxTransformer2DLoadersMixin
|
||||
from .transformer_sd3 import SD3Transformer2DLoadersMixin
|
||||
from .unet import UNet2DConditionLoadersMixin
|
||||
from .utils import AttnProcsLayers
|
||||
|
||||
if is_transformers_available():
|
||||
from .ip_adapter import IPAdapterMixin
|
||||
from .ip_adapter import (
|
||||
FluxIPAdapterMixin,
|
||||
IPAdapterMixin,
|
||||
SD3IPAdapterMixin,
|
||||
)
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
SanaLoraLoaderMixin,
|
||||
SD3LoraLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
|
||||
@@ -33,15 +33,20 @@ from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
JointAttnProcessor2_0,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -187,7 +192,7 @@ class IPAdapterMixin:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
@@ -348,3 +353,519 @@ class IPAdapterMixin:
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class FluxIPAdapterMixin:
|
||||
"""Mixin for handling Flux IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
weight_name: Union[str, List[str]],
|
||||
subfolder: Optional[Union[str, List[str]]] = "",
|
||||
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||
image_encoder_subfolder: Optional[str] = "",
|
||||
image_encoder_dtype: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`weight_name`.
|
||||
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
|
||||
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
|
||||
for key in f.keys():
|
||||
if any(key.startswith(prefix) for prefix in image_proj_keys):
|
||||
diffusers_name = ".".join(key.split(".")[1:])
|
||||
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
|
||||
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
|
||||
diffusers_name = (
|
||||
".".join(key.split(".")[1:])
|
||||
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
|
||||
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
|
||||
.replace("processor.", "")
|
||||
)
|
||||
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_pretrained_model_name_or_path is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
|
||||
image_encoder = (
|
||||
CLIPVisionModelWithProjection.from_pretrained(
|
||||
image_encoder_pretrained_model_name_or_path,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
.to(self.device, dtype=image_encoder_dtype)
|
||||
.eval()
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||
|
||||
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
|
||||
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
|
||||
number of IP adapters and each must match the number of blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
|
||||
def LinearStrengthModel(start, finish, size):
|
||||
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
|
||||
|
||||
|
||||
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
|
||||
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||
```
|
||||
"""
|
||||
transformer = self.transformer
|
||||
if not isinstance(scale, list):
|
||||
scale = [[scale] * transformer.config.num_layers]
|
||||
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
|
||||
if len(scale) != transformer.config.num_layers:
|
||||
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
|
||||
scale = [scale]
|
||||
|
||||
scale_configs = scale
|
||||
|
||||
key_id = 0
|
||||
for attn_name, attn_processor in transformer.attn_processors.items():
|
||||
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
f"{len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
attn_processor.scale[i] = scale_config[key_id]
|
||||
key_id += 1
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.transformer.encoder_hid_proj = None
|
||||
self.transformer.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Transformer attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.transformer.attn_processors.items():
|
||||
attn_processor_class = FluxAttnProcessor2_0()
|
||||
attn_procs[name] = (
|
||||
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
|
||||
)
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class SD3IPAdapterMixin:
|
||||
"""Mixin for handling StableDiffusion 3 IP Adapters."""
|
||||
|
||||
@property
|
||||
def is_ip_adapter_active(self) -> bool:
|
||||
"""Checks if IP-Adapter is loaded and scale > 0.
|
||||
|
||||
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
|
||||
the image context is irrelevant.
|
||||
|
||||
Returns:
|
||||
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
|
||||
"""
|
||||
scales = [
|
||||
attn_proc.scale
|
||||
for attn_proc in self.transformer.attn_processors.values()
|
||||
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
|
||||
]
|
||||
|
||||
return len(scales) > 0 and any(scale > 0 for scale in scales)
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
weight_name: str = "ip-adapter.safetensors",
|
||||
subfolder: Optional[str] = None,
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
weight_name (`str`, defaults to "ip-adapter.safetensors"):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
# Load the main state dict first
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
# Commons args for loading image encoder and image processor
|
||||
kwargs = {
|
||||
"low_cpu_mem_usage": low_cpu_mem_usage,
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to(
|
||||
self.device, dtype=self.dtype
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# Load IP-Adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: float) -> None:
|
||||
"""
|
||||
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
|
||||
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
|
||||
the model to produce more diverse images, but they may not be as aligned with the image prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.set_ip_adapter_scale(0.6)
|
||||
>>> ...
|
||||
```
|
||||
|
||||
Args:
|
||||
scale (float):
|
||||
IP-Adapter scale to be set.
|
||||
|
||||
"""
|
||||
for attn_processor in self.transformer.attn_processors.values():
|
||||
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self) -> None:
|
||||
"""
|
||||
Unloads the IP Adapter weights.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# Remove image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=None)
|
||||
|
||||
# Remove feature extractor
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=None)
|
||||
|
||||
# Remove image projection
|
||||
self.transformer.image_proj = None
|
||||
|
||||
# Restore original attention processors layers
|
||||
attn_procs = {
|
||||
name: (
|
||||
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
|
||||
)
|
||||
for name, value in self.transformer.attn_processors.items()
|
||||
}
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
@@ -643,7 +643,11 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
old_state_dict,
|
||||
new_state_dict,
|
||||
old_key,
|
||||
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
|
||||
[
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
|
||||
f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
|
||||
],
|
||||
)
|
||||
|
||||
if "down" in old_key:
|
||||
@@ -663,3 +667,309 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
original_state_dict_keys = list(original_state_dict.keys())
|
||||
num_layers = 19
|
||||
num_single_layers = 38
|
||||
inner_dim = 3072
|
||||
mlp_ratio = 4.0
|
||||
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
## time_text_embed.timestep_embedder <- time_in
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.weight")
|
||||
if f"time_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_1.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"time_in.in_layer.{lora_key}.bias")
|
||||
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.weight")
|
||||
if f"time_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.timestep_embedder.linear_2.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"time_in.out_layer.{lora_key}.bias")
|
||||
|
||||
## time_text_embed.text_embedder <- vector_in
|
||||
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"vector_in.in_layer.{lora_key}.weight"
|
||||
)
|
||||
if f"vector_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.text_embedder.linear_1.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"vector_in.in_layer.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"vector_in.out_layer.{lora_key}.weight"
|
||||
)
|
||||
if f"vector_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"time_text_embed.text_embedder.linear_2.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"vector_in.out_layer.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# guidance
|
||||
has_guidance = any("guidance" in k for k in original_state_dict)
|
||||
if has_guidance:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.weight")
|
||||
if f"guidance_in.in_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_1.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"guidance_in.in_layer.{lora_key}.bias")
|
||||
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.weight"
|
||||
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.weight")
|
||||
if f"guidance_in.out_layer.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[
|
||||
f"time_text_embed.guidance_embedder.linear_2.{lora_key}.bias"
|
||||
] = original_state_dict.pop(f"guidance_in.out_layer.{lora_key}.bias")
|
||||
|
||||
# context_embedder
|
||||
converted_state_dict[f"context_embedder.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"txt_in.{lora_key}.weight"
|
||||
)
|
||||
if f"txt_in.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"context_embedder.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"txt_in.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# x_embedder
|
||||
converted_state_dict[f"x_embedder.{lora_key}.weight"] = original_state_dict.pop(f"img_in.{lora_key}.weight")
|
||||
if f"img_in.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"x_embedder.{lora_key}.bias"] = original_state_dict.pop(f"img_in.{lora_key}.bias")
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
# norms
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_mod.lin.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_mod.lin.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_mod.lin.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# Q, K, V
|
||||
if lora_key == "lora_A":
|
||||
sample_lora_weight = original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight")
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
|
||||
context_lora_weight = original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight")
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
else:
|
||||
sample_q, sample_k, sample_v = torch.chunk(
|
||||
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.weight"), 3, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"), 3, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
|
||||
|
||||
if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
original_state_dict.pop(f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias"), 3, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
|
||||
|
||||
if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
original_state_dict.pop(f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"), 3, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
|
||||
|
||||
# ff img_mlp
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_mlp.0.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_mlp.0.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_mlp.2.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_mlp.2.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_attn.proj.{lora_key}.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# qk_norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
|
||||
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
|
||||
)
|
||||
|
||||
# single transfomer blocks
|
||||
for i in range(num_single_layers):
|
||||
block_prefix = f"single_transformer_blocks.{i}."
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
# norm.linear <- single_blocks.0.modulation.lin
|
||||
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"single_blocks.{i}.modulation.lin.{lora_key}.weight"
|
||||
)
|
||||
if f"single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"single_blocks.{i}.modulation.lin.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# Q, K, V, mlp
|
||||
mlp_hidden_dim = int(inner_dim * mlp_ratio)
|
||||
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
|
||||
|
||||
if lora_key == "lora_A":
|
||||
lora_weight = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight")
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
|
||||
|
||||
if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
|
||||
lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
else:
|
||||
q, k, v, mlp = torch.split(
|
||||
original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.weight"), split_size, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
|
||||
|
||||
if f"single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(
|
||||
original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias"), split_size, dim=0
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
|
||||
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"single_blocks.{i}.linear2.{lora_key}.weight"
|
||||
)
|
||||
if f"single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"single_blocks.{i}.linear2.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# qk norm
|
||||
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
|
||||
f"single_blocks.{i}.norm.query_norm.scale"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
|
||||
f"single_blocks.{i}.norm.key_norm.scale"
|
||||
)
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"final_layer.linear.{lora_key}.weight"
|
||||
)
|
||||
if f"final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"final_layer.linear.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"norm_out.linear.{lora_key}.weight"] = swap_scale_shift(
|
||||
original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.weight")
|
||||
)
|
||||
if f"final_layer.adaLN_modulation.1.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"norm_out.linear.{lora_key}.bias"] = swap_scale_shift(
|
||||
original_state_dict.pop(f"final_layer.adaLN_modulation.1.{lora_key}.bias")
|
||||
)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -53,9 +53,63 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"FluxTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"MochiTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"SanaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
def _maybe_adjust_config(config):
|
||||
"""
|
||||
We may run into some ambiguous configuration values when a model has module names, sharing a common prefix
|
||||
(`proj_out.weight` and `blocks.transformer.proj_out.weight`, for example) and they have different LoRA ranks. This
|
||||
method removes the ambiguity by following what is described here:
|
||||
https://github.com/huggingface/diffusers/pull/9985#issuecomment-2493840028.
|
||||
"""
|
||||
rank_pattern = config["rank_pattern"].copy()
|
||||
target_modules = config["target_modules"]
|
||||
original_r = config["r"]
|
||||
|
||||
for key in list(rank_pattern.keys()):
|
||||
key_rank = rank_pattern[key]
|
||||
|
||||
# try to detect ambiguity
|
||||
# `target_modules` can also be a str, in which case this loop would loop
|
||||
# over the chars of the str. The technically correct way to match LoRA keys
|
||||
# in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
|
||||
# But this cuts it for now.
|
||||
exact_matches = [mod for mod in target_modules if mod == key]
|
||||
substring_matches = [mod for mod in target_modules if key in mod and mod != key]
|
||||
ambiguous_key = key
|
||||
|
||||
if exact_matches and substring_matches:
|
||||
# if ambiguous we update the rank associated with the ambiguous key (`proj_out`, for example)
|
||||
config["r"] = key_rank
|
||||
# remove the ambiguous key from `rank_pattern` and update its rank to `r`, instead
|
||||
del config["rank_pattern"][key]
|
||||
for mod in substring_matches:
|
||||
# avoid overwriting if the module already has a specific rank
|
||||
if mod not in config["rank_pattern"]:
|
||||
config["rank_pattern"][mod] = original_r
|
||||
|
||||
# update the rest of the keys with the `original_r`
|
||||
for mod in target_modules:
|
||||
if mod != ambiguous_key and mod not in config["rank_pattern"]:
|
||||
config["rank_pattern"][mod] = original_r
|
||||
|
||||
# handle alphas to deal with cases like
|
||||
# https://github.com/huggingface/diffusers/pull/9999#issuecomment-2516180777
|
||||
has_different_ranks = len(config["rank_pattern"]) > 1 and list(config["rank_pattern"])[0] != config["r"]
|
||||
if has_different_ranks:
|
||||
config["lora_alpha"] = config["r"]
|
||||
alpha_pattern = {}
|
||||
for module_name, rank in config["rank_pattern"].items():
|
||||
alpha_pattern[module_name] = rank
|
||||
config["alpha_pattern"] = alpha_pattern
|
||||
|
||||
return config
|
||||
|
||||
|
||||
class PeftAdapterMixin:
|
||||
"""
|
||||
A class containing all functions for loading and using adapters weights that are supported in PEFT library. For
|
||||
@@ -154,6 +208,7 @@ class PeftAdapterMixin:
|
||||
weights.
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -216,7 +271,9 @@ class PeftAdapterMixin:
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
# Cannot figure out rank from lora layers that don't have atleast 2 dimensions.
|
||||
# Bias layers in LoRA only have a single dimension
|
||||
if "lora_B" in key and val.ndim > 1:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
@@ -224,6 +281,8 @@ class PeftAdapterMixin:
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
lora_config_kwargs = _maybe_adjust_config(lora_config_kwargs)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
@@ -233,8 +292,18 @@ class PeftAdapterMixin:
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `lora_bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(self)
|
||||
@@ -251,8 +320,22 @@ class PeftAdapterMixin:
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
# To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
|
||||
# we should also delete the `peft_config` associated to the `adapter_name`.
|
||||
try:
|
||||
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
|
||||
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
|
||||
except RuntimeError as e:
|
||||
for module in self.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
for active_adapter in active_adapters:
|
||||
if adapter_name in active_adapter:
|
||||
module.delete_adapter(adapter_name)
|
||||
|
||||
self.peft_config.pop(adapter_name)
|
||||
logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
|
||||
raise
|
||||
|
||||
warn_msg = ""
|
||||
if incompatible_keys is not None:
|
||||
|
||||
@@ -17,16 +17,22 @@ import re
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
@@ -82,6 +88,19 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"LTXVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLLTXVideo": {
|
||||
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
||||
"MochiTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -202,6 +221,8 @@ class FromOriginalModelMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
@@ -215,11 +236,17 @@ class FromOriginalModelMixin:
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
|
||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||
if original_config:
|
||||
if original_config is not None:
|
||||
if "config_mapping_fn" in mapping_functions:
|
||||
config_mapping_fn = mapping_functions["config_mapping_fn"]
|
||||
else:
|
||||
@@ -243,7 +270,7 @@ class FromOriginalModelMixin:
|
||||
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
if config:
|
||||
if config is not None:
|
||||
if isinstance(config, str):
|
||||
default_pretrained_model_config_name = config
|
||||
else:
|
||||
@@ -270,6 +297,7 @@ class FromOriginalModelMixin:
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
@@ -296,8 +324,36 @@ class FromOriginalModelMixin:
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||
if not isinstance(keep_in_fp32_modules, list):
|
||||
keep_in_fp32_modules = [keep_in_fp32_modules]
|
||||
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
device_map=None,
|
||||
state_dict=diffusers_format_checkpoint,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
param_device = torch.device(device) if device else torch.device("cpu")
|
||||
unexpected_keys = load_model_dict_into_meta(
|
||||
model,
|
||||
diffusers_format_checkpoint,
|
||||
dtype=torch_dtype,
|
||||
device=param_device,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
@@ -311,7 +367,11 @@ class FromOriginalModelMixin:
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
if torch_dtype is not None:
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.postprocess_model(model)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if torch_dtype is not None and hf_quantizer is None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
model.eval()
|
||||
|
||||
@@ -81,8 +81,14 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
||||
"stable_cascade_stage_c": "clip_txt_mapper.weight",
|
||||
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
"sd3": [
|
||||
"joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
],
|
||||
"sd35_large": [
|
||||
"joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
"model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
|
||||
],
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
|
||||
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
@@ -92,6 +98,16 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
|
||||
],
|
||||
"ltx-video": [
|
||||
"model.diffusion_model.patchify_proj.weight",
|
||||
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
|
||||
"patchify_proj.weight",
|
||||
"transformer_blocks.27.scale_shift_table",
|
||||
"vae.per_channel_statistics.mean-of-means",
|
||||
],
|
||||
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
|
||||
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
|
||||
"mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -137,7 +153,15 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"},
|
||||
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
|
||||
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
|
||||
"flux-fill": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Fill-dev"},
|
||||
"flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"},
|
||||
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
||||
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
|
||||
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
||||
"autoencoder-dc-f32c32-sana": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"},
|
||||
"mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -529,13 +553,20 @@ def infer_diffusers_model_type(checkpoint):
|
||||
):
|
||||
model_type = "stable_cascade_stage_b"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
|
||||
if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864:
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
|
||||
checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
|
||||
):
|
||||
if "model.diffusion_model.pos_embed" in checkpoint:
|
||||
key = "model.diffusion_model.pos_embed"
|
||||
else:
|
||||
key = "pos_embed"
|
||||
|
||||
if checkpoint[key].shape[1] == 36864:
|
||||
model_type = "sd3"
|
||||
elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456:
|
||||
elif checkpoint[key].shape[1] == 147456:
|
||||
model_type = "sd35_medium"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
|
||||
model_type = "sd35_large"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
|
||||
@@ -561,9 +592,38 @@ def infer_diffusers_model_type(checkpoint):
|
||||
if any(
|
||||
g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"]
|
||||
):
|
||||
model_type = "flux-dev"
|
||||
if checkpoint["img_in.weight"].shape[1] == 384:
|
||||
model_type = "flux-fill"
|
||||
|
||||
elif checkpoint["img_in.weight"].shape[1] == 128:
|
||||
model_type = "flux-depth"
|
||||
else:
|
||||
model_type = "flux-dev"
|
||||
else:
|
||||
model_type = "flux-schnell"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
||||
model_type = "ltx-video"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["autoencoder-dc"] in checkpoint:
|
||||
encoder_key = "encoder.project_in.conv.conv.bias"
|
||||
decoder_key = "decoder.project_in.main.conv.weight"
|
||||
|
||||
if CHECKPOINT_KEY_NAMES["autoencoder-dc-sana"] in checkpoint:
|
||||
model_type = "autoencoder-dc-f32c32-sana"
|
||||
|
||||
elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 32:
|
||||
model_type = "autoencoder-dc-f32c32"
|
||||
|
||||
elif checkpoint[encoder_key].shape[-1] == 64 and checkpoint[decoder_key].shape[1] == 128:
|
||||
model_type = "autoencoder-dc-f64c128"
|
||||
|
||||
else:
|
||||
model_type = "autoencoder-dc-f128c512"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["mochi-1-preview"]):
|
||||
model_type = "mochi-1-preview"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1704,6 +1764,12 @@ def swap_scale_shift(weight, dim):
|
||||
return new_weight
|
||||
|
||||
|
||||
def swap_proj_gate(weight):
|
||||
proj, gate = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([gate, proj], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def get_attn2_layers(state_dict):
|
||||
attn2_layers = []
|
||||
for key in state_dict.keys():
|
||||
@@ -2198,3 +2264,261 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae" not in key}
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"model.diffusion_model.": "",
|
||||
"patchify_proj": "proj_in",
|
||||
"adaln_single": "time_embed",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}
|
||||
|
||||
def remove_keys_(key: str, state_dict):
|
||||
state_dict.pop(key)
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
# common
|
||||
"vae.": "",
|
||||
# decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0",
|
||||
"up_blocks.2": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.3": "up_blocks.1",
|
||||
"up_blocks.4": "up_blocks.2.conv_in",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
"up_blocks.7": "up_blocks.3.conv_in",
|
||||
"up_blocks.8": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.9": "up_blocks.3",
|
||||
# encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.0.conv_out",
|
||||
"down_blocks.3": "down_blocks.1",
|
||||
"down_blocks.4": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.5": "down_blocks.1.conv_out",
|
||||
"down_blocks.6": "down_blocks.2",
|
||||
"down_blocks.7": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.8": "down_blocks.3",
|
||||
"down_blocks.9": "mid_block",
|
||||
# common
|
||||
"conv_shortcut": "conv_shortcut.conv",
|
||||
"res_blocks": "resnets",
|
||||
"norm3.norm": "norm3",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
}
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key
|
||||
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_autoencoder_dc_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
def remap_qkv_(key: str, state_dict):
|
||||
qkv = state_dict.pop(key)
|
||||
q, k, v = torch.chunk(qkv, 3, dim=0)
|
||||
parent_module, _, _ = key.rpartition(".qkv.conv.weight")
|
||||
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
|
||||
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
|
||||
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
|
||||
|
||||
def remap_proj_conv_(key: str, state_dict):
|
||||
parent_module, _, _ = key.rpartition(".proj.conv.weight")
|
||||
state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
|
||||
|
||||
AE_KEYS_RENAME_DICT = {
|
||||
# common
|
||||
"main.": "",
|
||||
"op_list.": "",
|
||||
"context_module": "attn",
|
||||
"local_module": "conv_out",
|
||||
# NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
|
||||
# If there were more scales, there would be more layers, so a loop would be better to handle this
|
||||
"aggreg.0.0": "to_qkv_multiscale.0.proj_in",
|
||||
"aggreg.0.1": "to_qkv_multiscale.0.proj_out",
|
||||
"depth_conv.conv": "conv_depth",
|
||||
"inverted_conv.conv": "conv_inverted",
|
||||
"point_conv.conv": "conv_point",
|
||||
"point_conv.norm": "norm",
|
||||
"conv.conv.": "conv.",
|
||||
"conv1.conv": "conv1",
|
||||
"conv2.conv": "conv2",
|
||||
"conv2.norm": "norm",
|
||||
"proj.norm": "norm_out",
|
||||
# encoder
|
||||
"encoder.project_in.conv": "encoder.conv_in",
|
||||
"encoder.project_out.0.conv": "encoder.conv_out",
|
||||
"encoder.stages": "encoder.down_blocks",
|
||||
# decoder
|
||||
"decoder.project_in.conv": "decoder.conv_in",
|
||||
"decoder.project_out.0": "decoder.norm_out",
|
||||
"decoder.project_out.2.conv": "decoder.conv_out",
|
||||
"decoder.stages": "decoder.up_blocks",
|
||||
}
|
||||
|
||||
AE_F32C32_F64C128_F128C512_KEYS = {
|
||||
"encoder.project_in.conv": "encoder.conv_in.conv",
|
||||
"decoder.project_out.2.conv": "decoder.conv_out.conv",
|
||||
}
|
||||
|
||||
AE_SPECIAL_KEYS_REMAP = {
|
||||
"qkv.conv.weight": remap_qkv_,
|
||||
"proj.conv.weight": remap_proj_conv_,
|
||||
}
|
||||
if "encoder.project_in.conv.bias" not in converted_state_dict:
|
||||
AE_KEYS_RENAME_DICT.update(AE_F32C32_F64C128_F128C512_KEYS)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_mochi_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
new_state_dict = {}
|
||||
|
||||
# Comfy checkpoints add this prefix
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
# Convert patch_embed
|
||||
new_state_dict["patch_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
new_state_dict["patch_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Convert time_embed
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop("t_embedder.mlp.0.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop("t_embedder.mlp.2.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
new_state_dict["time_embed.pooler.to_kv.weight"] = checkpoint.pop("t5_y_embedder.to_kv.weight")
|
||||
new_state_dict["time_embed.pooler.to_kv.bias"] = checkpoint.pop("t5_y_embedder.to_kv.bias")
|
||||
new_state_dict["time_embed.pooler.to_q.weight"] = checkpoint.pop("t5_y_embedder.to_q.weight")
|
||||
new_state_dict["time_embed.pooler.to_q.bias"] = checkpoint.pop("t5_y_embedder.to_q.bias")
|
||||
new_state_dict["time_embed.pooler.to_out.weight"] = checkpoint.pop("t5_y_embedder.to_out.weight")
|
||||
new_state_dict["time_embed.pooler.to_out.bias"] = checkpoint.pop("t5_y_embedder.to_out.bias")
|
||||
new_state_dict["time_embed.caption_proj.weight"] = checkpoint.pop("t5_yproj.weight")
|
||||
new_state_dict["time_embed.caption_proj.bias"] = checkpoint.pop("t5_yproj.bias")
|
||||
|
||||
# Convert transformer blocks
|
||||
num_layers = 48
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
old_prefix = f"blocks.{i}."
|
||||
|
||||
# norm1
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = checkpoint.pop(old_prefix + "mod_x.weight")
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = checkpoint.pop(old_prefix + "mod_x.bias")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "norm1_context.linear.weight"] = checkpoint.pop(old_prefix + "mod_y.weight")
|
||||
new_state_dict[block_prefix + "norm1_context.linear.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
else:
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = checkpoint.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = checkpoint.pop(old_prefix + "mod_y.bias")
|
||||
|
||||
# Visual attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_x.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_q.weight"] = checkpoint.pop(old_prefix + "attn.q_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.norm_k.weight"] = checkpoint.pop(old_prefix + "attn.k_norm_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = checkpoint.pop(old_prefix + "attn.proj_x.weight")
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = checkpoint.pop(old_prefix + "attn.proj_x.bias")
|
||||
|
||||
# Context attention
|
||||
qkv_weight = checkpoint.pop(old_prefix + "attn.qkv_y.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.q_norm_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.k_norm_y.weight"
|
||||
)
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = checkpoint.pop(
|
||||
old_prefix + "attn.proj_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = checkpoint.pop(old_prefix + "attn.proj_y.bias")
|
||||
|
||||
# MLP
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_x.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_x.w2.weight")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = swap_proj_gate(
|
||||
checkpoint.pop(old_prefix + "mlp_y.w1.weight")
|
||||
)
|
||||
new_state_dict[block_prefix + "ff_context.net.2.weight"] = checkpoint.pop(old_prefix + "mlp_y.w2.weight")
|
||||
|
||||
# Output layers
|
||||
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(checkpoint.pop("final_layer.mod.weight"), dim=0)
|
||||
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(checkpoint.pop("final_layer.mod.bias"), dim=0)
|
||||
new_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
new_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
new_state_dict["pos_frequencies"] = checkpoint.pop("pos_frequencies")
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FluxTransformer2DLoadersMixin:
|
||||
"""
|
||||
Load layers into a [`FluxTransformer2DModel`].
|
||||
"""
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
updated_state_dict = {}
|
||||
image_projection = None
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
if "proj.weight" in state_dict:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
if state_dict["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embeds = 16
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
|
||||
|
||||
with init_context():
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj", "image_embeds")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
|
||||
from ..models.attention_processor import (
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 0
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for name in self.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
else:
|
||||
cross_attention_dim = self.config.joint_attention_dim
|
||||
hidden_size = self.inner_dim
|
||||
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
num_image_text_embed = 4
|
||||
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embed = 16
|
||||
# IP-Adapter
|
||||
num_image_text_embeds += [num_image_text_embed]
|
||||
|
||||
with init_context():
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
num_tokens=num_image_text_embeds,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
|
||||
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
else:
|
||||
device = self.device
|
||||
dtype = self.dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
|
||||
|
||||
key_id += 1
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
image_projection_layers = []
|
||||
for state_dict in state_dicts:
|
||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
|
||||
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
|
||||
)
|
||||
image_projection_layers.append(image_projection_layer)
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict
|
||||
|
||||
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ..models.embeddings import IPAdapterTimeImageProjection
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
|
||||
|
||||
class SD3Transformer2DLoadersMixin:
|
||||
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
|
||||
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (`Dict`):
|
||||
State dict with keys "ip_adapter", which contains parameters for attention processors, and
|
||||
"image_proj", which contains parameters for image projection net.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
# IP-Adapter cross attention parameters
|
||||
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1]
|
||||
|
||||
# Dict where key is transformer layer index, value is attention processor's state dict
|
||||
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
|
||||
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
|
||||
for key, weights in state_dict["ip_adapter"].items():
|
||||
idx, name = key.split(".", maxsplit=1)
|
||||
layer_state_dict[int(idx)][name] = weights
|
||||
|
||||
# Create IP-Adapter attention processor
|
||||
attn_procs = {}
|
||||
for idx, name in enumerate(self.attn_processors.keys()):
|
||||
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size,
|
||||
ip_hidden_states_dim=ip_hidden_states_dim,
|
||||
head_dim=self.config.attention_head_dim,
|
||||
timesteps_emb_dim=timesteps_emb_dim,
|
||||
).to(self.device, dtype=self.dtype)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(
|
||||
attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype
|
||||
)
|
||||
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
# Image projetion parameters
|
||||
embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1]
|
||||
output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0]
|
||||
hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0]
|
||||
heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64
|
||||
num_queries = state_dict["image_proj"]["latents"].shape[1]
|
||||
timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
# Image projection
|
||||
self.image_proj = IPAdapterTimeImageProjection(
|
||||
embed_dim=embed_dim,
|
||||
output_dim=output_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
heads=heads,
|
||||
num_queries=num_queries,
|
||||
timestep_in_dim=timestep_in_dim,
|
||||
).to(device=self.device, dtype=self.dtype)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
|
||||
else:
|
||||
load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype)
|
||||
@@ -492,6 +492,9 @@ class UNet2DConditionLoadersMixin:
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
|
||||
else:
|
||||
deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
|
||||
deprecate("save_attn_procs", "0.40.0", deprecation_message)
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
|
||||
|
||||
|
||||
@@ -27,9 +27,12 @@ _import_structure = {}
|
||||
if is_torch_available():
|
||||
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
|
||||
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
|
||||
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
|
||||
_import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"]
|
||||
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
@@ -44,6 +47,7 @@ if is_torch_available():
|
||||
]
|
||||
_import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
|
||||
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
@@ -57,12 +61,15 @@ if is_torch_available():
|
||||
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
|
||||
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
|
||||
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
|
||||
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
@@ -88,9 +95,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .adapter import MultiAdapter, T2IAdapter
|
||||
from .autoencoders import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AutoencoderDC,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLAllegro,
|
||||
AutoencoderKLCogVideoX,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMochi,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderOobleck,
|
||||
@@ -100,6 +110,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .controlnets import (
|
||||
ControlNetModel,
|
||||
ControlNetUnionModel,
|
||||
ControlNetXSAdapter,
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
@@ -122,11 +133,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
DualTransformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanVideoTransformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SanaTransformer2DModel,
|
||||
SD3Transformer2DModel,
|
||||
StableAudioDiTModel,
|
||||
T5FilmDecoder,
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..utils import deprecate
|
||||
from ..utils.import_utils import is_torch_npu_available
|
||||
from ..utils.import_utils import is_torch_npu_available, is_torch_version
|
||||
|
||||
|
||||
if is_torch_npu_available():
|
||||
@@ -79,10 +79,10 @@ class GELU(nn.Module):
|
||||
self.approximate = approximate
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
if gate.device.type != "mps":
|
||||
return F.gelu(gate, approximate=self.approximate)
|
||||
# mps: gelu is not implemented for float16
|
||||
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
||||
if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
|
||||
# fp16 gelu not supported on mps before torch 2.0
|
||||
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
||||
return F.gelu(gate, approximate=self.approximate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
@@ -105,10 +105,10 @@ class GEGLU(nn.Module):
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
if gate.device.type != "mps":
|
||||
return F.gelu(gate)
|
||||
# mps: gelu is not implemented for float16
|
||||
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
||||
if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
|
||||
# fp16 gelu not supported on mps before torch 2.0
|
||||
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
||||
return F.gelu(gate)
|
||||
|
||||
def forward(self, hidden_states, *args, **kwargs):
|
||||
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
||||
@@ -164,3 +164,15 @@ class ApproximateGELU(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.proj(x)
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class LinearActivation(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
|
||||
super().__init__()
|
||||
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
self.activation = get_activation(activation)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
return self.activation(hidden_states)
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
|
||||
from .attention_processor import Attention, JointAttnProcessor2_0
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
|
||||
@@ -188,8 +188,13 @@ class JointTransformerBlock(nn.Module):
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
temb: torch.FloatTensor,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
if self.use_dual_attention:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
|
||||
hidden_states, emb=temb
|
||||
@@ -206,7 +211,9 @@ class JointTransformerBlock(nn.Module):
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
@@ -214,7 +221,7 @@ class JointTransformerBlock(nn.Module):
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
if self.use_dual_attention:
|
||||
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
|
||||
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
|
||||
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
|
||||
hidden_states = hidden_states + attn_output2
|
||||
|
||||
@@ -1222,6 +1229,8 @@ class FeedForward(nn.Module):
|
||||
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "swiglu":
|
||||
act_fn = SwiGLU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "linear-silu":
|
||||
act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
|
||||
@@ -216,8 +216,8 @@ class FlaxAttention(nn.Module):
|
||||
hidden_states = jax_memory_efficient_attention(
|
||||
query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 0, 2)
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
else:
|
||||
# compute attentions
|
||||
if self.split_head_dim:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,10 @@
|
||||
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
|
||||
from .autoencoder_dc import AutoencoderDC
|
||||
from .autoencoder_kl import AutoencoderKL
|
||||
from .autoencoder_kl_allegro import AutoencoderKLAllegro
|
||||
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
|
||||
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
|
||||
@@ -0,0 +1,620 @@
|
||||
# Copyright 2024 MIT, Tsinghua University, NVIDIA CORPORATION and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import SanaMultiscaleLinearAttention
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm, get_normalization
|
||||
from ..transformers.sana_transformer import GLUMBConv
|
||||
from .vae import DecoderOutput, EncoderOutput
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
norm_type: str = "batch_norm",
|
||||
act_fn: str = "relu6",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm_type = norm_type
|
||||
|
||||
self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
|
||||
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
|
||||
self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
|
||||
self.norm = get_normalization(norm_type, out_channels)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
else:
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class EfficientViTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
mult: float = 1.0,
|
||||
attention_head_dim: int = 32,
|
||||
qkv_multiscales: Tuple[int, ...] = (5,),
|
||||
norm_type: str = "batch_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = SanaMultiscaleLinearAttention(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
mult=mult,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type,
|
||||
kernel_sizes=qkv_multiscales,
|
||||
residual_connection=True,
|
||||
)
|
||||
|
||||
self.conv_out = GLUMBConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
norm_type="rms_norm",
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.attn(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_block(
|
||||
block_type: str,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_mutliscales: Tuple[int] = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
|
||||
elif block_type == "EfficientViTBlock":
|
||||
block = EfficientViTBlock(
|
||||
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Block with {block_type=} is not supported.")
|
||||
|
||||
return block
|
||||
|
||||
|
||||
class DCDownBlock2d(nn.Module):
|
||||
def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.downsample = downsample
|
||||
self.factor = 2
|
||||
self.stride = 1 if downsample else 2
|
||||
self.group_size = in_channels * self.factor**2 // out_channels
|
||||
self.shortcut = shortcut
|
||||
|
||||
out_ratio = self.factor**2
|
||||
if downsample:
|
||||
assert out_channels % out_ratio == 0
|
||||
out_channels = out_channels // out_ratio
|
||||
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=self.stride,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
x = self.conv(hidden_states)
|
||||
if self.downsample:
|
||||
x = F.pixel_unshuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = F.pixel_unshuffle(hidden_states, self.factor)
|
||||
y = y.unflatten(1, (-1, self.group_size))
|
||||
y = y.mean(dim=2)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DCUpBlock2d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
interpolate: bool = False,
|
||||
shortcut: bool = True,
|
||||
interpolation_mode: str = "nearest",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.interpolate = interpolate
|
||||
self.interpolation_mode = interpolation_mode
|
||||
self.shortcut = shortcut
|
||||
self.factor = 2
|
||||
self.repeats = out_channels * self.factor**2 // in_channels
|
||||
|
||||
out_ratio = self.factor**2
|
||||
|
||||
if not interpolate:
|
||||
out_channels = out_channels * out_ratio
|
||||
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.interpolate:
|
||||
x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = self.conv(hidden_states)
|
||||
x = F.pixel_shuffle(x, self.factor)
|
||||
|
||||
if self.shortcut:
|
||||
y = hidden_states.repeat_interleave(self.repeats, dim=1)
|
||||
y = F.pixel_shuffle(y, self.factor)
|
||||
hidden_states = x + y
|
||||
else:
|
||||
hidden_states = x
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
out_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
else:
|
||||
self.conv_in = DCDownBlock2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=False,
|
||||
)
|
||||
|
||||
down_blocks = []
|
||||
for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
|
||||
down_block_list = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type="rms_norm",
|
||||
act_fn="silu",
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
down_block_list.append(block)
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
downsample_block = DCDownBlock2d(
|
||||
in_channels=out_channel,
|
||||
out_channels=block_out_channels[i + 1],
|
||||
downsample=downsample_block_type == "pixel_unshuffle",
|
||||
shortcut=True,
|
||||
)
|
||||
down_block_list.append(downsample_block)
|
||||
|
||||
down_blocks.append(nn.Sequential(*down_block_list))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
|
||||
|
||||
self.out_shortcut = out_shortcut
|
||||
if out_shortcut:
|
||||
self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
if self.out_shortcut:
|
||||
x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
|
||||
x = x.mean(dim=2)
|
||||
hidden_states = self.conv_out(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
latent_channels: int,
|
||||
attention_head_dim: int = 32,
|
||||
block_type: Union[str, Tuple[str]] = "ResBlock",
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512, 1024, 1024),
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2, 2, 2),
|
||||
qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
norm_type: Union[str, Tuple[str]] = "rms_norm",
|
||||
act_fn: Union[str, Tuple[str]] = "silu",
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
in_shortcut: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_blocks = len(block_out_channels)
|
||||
|
||||
if isinstance(block_type, str):
|
||||
block_type = (block_type,) * num_blocks
|
||||
if isinstance(norm_type, str):
|
||||
norm_type = (norm_type,) * num_blocks
|
||||
if isinstance(act_fn, str):
|
||||
act_fn = (act_fn,) * num_blocks
|
||||
|
||||
self.conv_in = nn.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
|
||||
|
||||
self.in_shortcut = in_shortcut
|
||||
if in_shortcut:
|
||||
self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
|
||||
|
||||
up_blocks = []
|
||||
for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
|
||||
up_block_list = []
|
||||
|
||||
if i < num_blocks - 1 and num_layers > 0:
|
||||
upsample_block = DCUpBlock2d(
|
||||
block_out_channels[i + 1],
|
||||
out_channel,
|
||||
interpolate=upsample_block_type == "interpolate",
|
||||
shortcut=True,
|
||||
)
|
||||
up_block_list.append(upsample_block)
|
||||
|
||||
for _ in range(num_layers):
|
||||
block = get_block(
|
||||
block_type[i],
|
||||
out_channel,
|
||||
out_channel,
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type[i],
|
||||
act_fn=act_fn[i],
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
up_block_list.append(block)
|
||||
|
||||
up_blocks.insert(0, nn.Sequential(*up_block_list))
|
||||
|
||||
self.up_blocks = nn.ModuleList(up_blocks)
|
||||
|
||||
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
|
||||
|
||||
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
|
||||
self.conv_act = nn.ReLU()
|
||||
self.conv_out = None
|
||||
|
||||
if layers_per_block[0] > 0:
|
||||
self.conv_out = nn.Conv2d(channels, in_channels, 3, 1, 1)
|
||||
else:
|
||||
self.conv_out = DCUpBlock2d(
|
||||
channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.in_shortcut:
|
||||
x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
|
||||
hidden_states = self.conv_in(hidden_states) + x
|
||||
else:
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for up_block in reversed(self.up_blocks):
|
||||
hidden_states = up_block(hidden_states)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
|
||||
[SANA](https://arxiv.org/abs/2410.10629).
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
||||
for all models (such as downloading or saving).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `3`):
|
||||
The number of input channels in samples.
|
||||
latent_channels (`int`, defaults to `32`):
|
||||
The number of channels in the latent space representation.
|
||||
encoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
|
||||
The type(s) of block to use in the encoder.
|
||||
decoder_block_types (`Union[str, Tuple[str]]`, defaults to `"ResBlock"`):
|
||||
The type(s) of block to use in the decoder.
|
||||
encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
The number of output channels for each block in the encoder.
|
||||
decoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512, 1024, 1024)`):
|
||||
The number of output channels for each block in the decoder.
|
||||
encoder_layers_per_block (`Tuple[int]`, defaults to `(2, 2, 2, 3, 3, 3)`):
|
||||
The number of layers per block in the encoder.
|
||||
decoder_layers_per_block (`Tuple[int]`, defaults to `(3, 3, 3, 3, 3, 3)`):
|
||||
The number of layers per block in the decoder.
|
||||
encoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
Multi-scale configurations for the encoder's QKV (query-key-value) transformations.
|
||||
decoder_qkv_multiscales (`Tuple[Tuple[int, ...], ...]`, defaults to `((), (), (), (5,), (5,), (5,))`):
|
||||
Multi-scale configurations for the decoder's QKV (query-key-value) transformations.
|
||||
upsample_block_type (`str`, defaults to `"pixel_shuffle"`):
|
||||
The type of block to use for upsampling in the decoder.
|
||||
downsample_block_type (`str`, defaults to `"pixel_unshuffle"`):
|
||||
The type of block to use for downsampling in the encoder.
|
||||
decoder_norm_types (`Union[str, Tuple[str]]`, defaults to `"rms_norm"`):
|
||||
The normalization type(s) to use in the decoder.
|
||||
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
|
||||
The activation function(s) to use in the decoder.
|
||||
scaling_factor (`float`, defaults to `1.0`):
|
||||
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
|
||||
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
|
||||
z * scaling_factor` before being passed to the diffusion model. When decoding, the latents are scaled back
|
||||
to the original scale with the formula: `z = 1 / scaling_factor * z`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
latent_channels: int = 32,
|
||||
attention_head_dim: int = 32,
|
||||
encoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
decoder_block_types: Union[str, Tuple[str]] = "ResBlock",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512, 1024, 1024),
|
||||
encoder_layers_per_block: Tuple[int] = (2, 2, 2, 3, 3, 3),
|
||||
decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3, 3, 3),
|
||||
encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (), (5,), (5,), (5,)),
|
||||
upsample_block_type: str = "pixel_shuffle",
|
||||
downsample_block_type: str = "pixel_unshuffle",
|
||||
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
|
||||
decoder_act_fns: Union[str, Tuple[str]] = "silu",
|
||||
scaling_factor: float = 1.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=encoder_block_types,
|
||||
block_out_channels=encoder_block_out_channels,
|
||||
layers_per_block=encoder_layers_per_block,
|
||||
qkv_multiscales=encoder_qkv_multiscales,
|
||||
downsample_block_type=downsample_block_type,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
in_channels=in_channels,
|
||||
latent_channels=latent_channels,
|
||||
attention_head_dim=attention_head_dim,
|
||||
block_type=decoder_block_types,
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
qkv_multiscales=decoder_qkv_multiscales,
|
||||
norm_type=decoder_norm_types,
|
||||
act_fn=decoder_act_fns,
|
||||
upsample_block_type=upsample_block_type,
|
||||
)
|
||||
|
||||
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
|
||||
self.temporal_compression_ratio = 1
|
||||
|
||||
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
||||
# to perform decoding of a single video latent at a time.
|
||||
self.use_slicing = False
|
||||
|
||||
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
||||
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
||||
# intermediate tiles together, the memory requirement can be lowered.
|
||||
self.use_tiling = False
|
||||
|
||||
# The minimal tile height and width for spatial tiling to be used
|
||||
self.tile_sample_min_height = 512
|
||||
self.tile_sample_min_width = 512
|
||||
|
||||
# The minimal distance between two spatial tiles
|
||||
self.tile_sample_stride_height = 448
|
||||
self.tile_sample_stride_width = 448
|
||||
|
||||
def enable_tiling(
|
||||
self,
|
||||
tile_sample_min_height: Optional[int] = None,
|
||||
tile_sample_min_width: Optional[int] = None,
|
||||
tile_sample_stride_height: Optional[float] = None,
|
||||
tile_sample_stride_width: Optional[float] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Enable tiled AE decoding. When this option is enabled, the AE will split the input tensor into tiles to compute
|
||||
decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
||||
processing larger images.
|
||||
|
||||
Args:
|
||||
tile_sample_min_height (`int`, *optional*):
|
||||
The minimum height required for a sample to be separated into tiles across the height dimension.
|
||||
tile_sample_min_width (`int`, *optional*):
|
||||
The minimum width required for a sample to be separated into tiles across the width dimension.
|
||||
tile_sample_stride_height (`int`, *optional*):
|
||||
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
|
||||
no tiling artifacts produced across the height dimension.
|
||||
tile_sample_stride_width (`int`, *optional*):
|
||||
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
||||
artifacts produced across the width dimension.
|
||||
"""
|
||||
self.use_tiling = True
|
||||
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
||||
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
||||
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
|
||||
decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
||||
return self.tiled_encode(x, return_dict=False)[0]
|
||||
|
||||
encoded = self.encoder(x)
|
||||
|
||||
return encoded
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[EncoderOutput, Tuple[torch.Tensor]]:
|
||||
r"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether to return a [`~models.vae.EncoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos. If `return_dict` is True, a
|
||||
[`~models.vae.EncoderOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
encoded = torch.cat(encoded_slices)
|
||||
else:
|
||||
encoded = self._encode(x)
|
||||
|
||||
if not return_dict:
|
||||
return (encoded,)
|
||||
return EncoderOutput(latent=encoded)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = z.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
|
||||
return self.tiled_decode(z, return_dict=False)[0]
|
||||
|
||||
decoded = self.decoder(z)
|
||||
|
||||
return decoded
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
r"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
if self.use_slicing and z.size(0) > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
|
||||
raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
|
||||
|
||||
def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
|
||||
encoded = self.encode(sample, return_dict=False)[0]
|
||||
decoded = self.decode(encoded, return_dict=False)[0]
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -30,6 +30,19 @@ from ..unets.unet_2d_blocks import (
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EncoderOutput(BaseOutput):
|
||||
r"""
|
||||
Output of encoding method.
|
||||
|
||||
Args:
|
||||
latent (`torch.Tensor` of shape `(batch_size, num_channels, latent_height, latent_width)`):
|
||||
The encoded latent.
|
||||
"""
|
||||
|
||||
latent: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecoderOutput(BaseOutput):
|
||||
r"""
|
||||
|
||||
@@ -15,6 +15,7 @@ if is_torch_available():
|
||||
SparseControlNetModel,
|
||||
SparseControlNetOutput,
|
||||
)
|
||||
from .controlnet_union import ControlNetUnionModel
|
||||
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
|
||||
|
||||
@@ -393,13 +393,19 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
if self.context_embedder is not None:
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if self.context_embedder is not None:
|
||||
|
||||
@@ -0,0 +1,832 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
get_down_block,
|
||||
)
|
||||
from ..unets.unet_2d_condition import UNet2DConditionModel
|
||||
from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
"""
|
||||
Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return input * torch.sigmoid(1.702 * input)
|
||||
|
||||
|
||||
class ResidualAttentionMlp(nn.Module):
|
||||
def __init__(self, d_model: int):
|
||||
super().__init__()
|
||||
self.c_fc = nn.Linear(d_model, d_model * 4)
|
||||
self.gelu = QuickGELU()
|
||||
self.c_proj = nn.Linear(d_model * 4, d_model)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.c_fc(x)
|
||||
x = self.gelu(x)
|
||||
x = self.c_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = nn.LayerNorm(d_model)
|
||||
self.mlp = ResidualAttentionMlp(d_model)
|
||||
self.ln_2 = nn.LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
A ControlNetUnion model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(48, 96, 192, 384)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (48, 96, 192, 384),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
num_control_type: int = 6,
|
||||
num_trans_channel: int = 320,
|
||||
num_trans_head: int = 8,
|
||||
num_trans_layer: int = 1,
|
||||
num_proj_channel: int = 320,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is not None:
|
||||
raise ValueError(f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None.")
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||
# 2. it projects from an arbitrary input dimension.
|
||||
#
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
# control net conditioning embedding
|
||||
self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
|
||||
conditioning_embedding_channels=block_out_channels[0],
|
||||
block_out_channels=conditioning_embedding_out_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
)
|
||||
|
||||
task_scale_factor = num_trans_channel**0.5
|
||||
self.task_embedding = nn.Parameter(task_scale_factor * torch.randn(num_control_type, num_trans_channel))
|
||||
self.transformer_layes = nn.ModuleList(
|
||||
[ResidualAttentionBlock(num_trans_channel, num_trans_head) for _ in range(num_trans_layer)]
|
||||
)
|
||||
self.spatial_ch_projs = zero_module(nn.Linear(num_trans_channel, num_proj_channel))
|
||||
self.control_type_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.control_add_embedding = TimestepEmbedding(addition_time_embed_dim * num_control_type, time_embed_dim)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.controlnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_down_blocks.append(controlnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channel = block_out_channels[-1]
|
||||
|
||||
controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||
controlnet_block = zero_module(controlnet_block)
|
||||
self.controlnet_mid_block = controlnet_block
|
||||
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
transformer_layers_per_block=transformer_layers_per_block[-1],
|
||||
in_channels=mid_block_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`ControlNetUnionModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`ControlNetUnionModel`]. All configuration options are also
|
||||
copied where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
)
|
||||
|
||||
controlnet = cls(
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
in_channels=unet.config.in_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=unet.config.down_block_types,
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
class_embed_type=unet.config.class_embed_type,
|
||||
num_class_embeds=unet.config.num_class_embeds,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
|
||||
controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||
|
||||
if controlnet.class_embedding:
|
||||
controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||
|
||||
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
||||
controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
||||
|
||||
return controlnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: List[torch.Tensor],
|
||||
control_type: torch.Tensor,
|
||||
control_type_idx: List[int],
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
|
||||
"""
|
||||
The [`ControlNetUnionModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
controlnet_cond (`List[torch.Tensor]`):
|
||||
The conditional input tensors.
|
||||
control_type (`torch.Tensor`):
|
||||
A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
|
||||
type is used.
|
||||
control_type_idx (`List[int]`):
|
||||
The indices of `control_type`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for ControlNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.controlnet.ControlNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.controlnet_conditioning_channel_order
|
||||
|
||||
if channel_order != "rgb":
|
||||
raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
control_embeds = self.control_type_proj(control_type.flatten())
|
||||
control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
|
||||
control_embeds = control_embeds.to(emb.dtype)
|
||||
control_emb = self.control_add_embedding(control_embeds)
|
||||
emb = emb + control_emb
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
inputs = []
|
||||
condition_list = []
|
||||
|
||||
for cond, control_idx in zip(controlnet_cond, control_type_idx):
|
||||
condition = self.controlnet_cond_embedding(cond)
|
||||
feat_seq = torch.mean(condition, dim=(2, 3))
|
||||
feat_seq = feat_seq + self.task_embedding[control_idx]
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(condition)
|
||||
|
||||
condition = sample
|
||||
feat_seq = torch.mean(condition, dim=(2, 3))
|
||||
inputs.append(feat_seq.unsqueeze(1))
|
||||
condition_list.append(condition)
|
||||
|
||||
x = torch.cat(inputs, dim=1)
|
||||
for layer in self.transformer_layes:
|
||||
x = layer(x)
|
||||
|
||||
controlnet_cond_fuser = sample * 0.0
|
||||
for idx, condition in enumerate(condition_list[:-1]):
|
||||
alpha = self.spatial_ch_projs(x[:, idx])
|
||||
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
|
||||
controlnet_cond_fuser += condition + alpha
|
||||
|
||||
sample = sample + controlnet_cond_fuser
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
|
||||
# 5. Control net blocks
|
||||
controlnet_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
|
||||
down_block_res_sample = controlnet_block(down_block_res_sample)
|
||||
controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = controlnet_down_block_res_samples
|
||||
|
||||
mid_block_res_sample = self.controlnet_mid_block(sample)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
|
||||
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
|
||||
else:
|
||||
down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
|
||||
mid_block_res_sample = mid_block_res_sample * conditioning_scale
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
|
||||
]
|
||||
mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
|
||||
if not return_dict:
|
||||
return (down_block_res_samples, mid_block_res_sample)
|
||||
|
||||
return ControlNetOutput(
|
||||
down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
|
||||
)
|
||||
@@ -84,6 +84,78 @@ def get_3d_sincos_pos_embed(
|
||||
temporal_size: int,
|
||||
spatial_interpolation_scale: float = 1.0,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
device: Optional[torch.device] = None,
|
||||
output_type: str = "np",
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Creates 3D sinusoidal positional embeddings.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`):
|
||||
The embedding dimension of inputs. It must be divisible by 16.
|
||||
spatial_size (`int` or `Tuple[int, int]`):
|
||||
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
|
||||
spatial dimensions (height and width).
|
||||
temporal_size (`int`):
|
||||
The temporal dimension of postional embeddings (number of frames).
|
||||
spatial_interpolation_scale (`float`, defaults to 1.0):
|
||||
Scale factor for spatial grid interpolation.
|
||||
temporal_interpolation_scale (`float`, defaults to 1.0):
|
||||
Scale factor for temporal grid interpolation.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
|
||||
embed_dim]`.
|
||||
"""
|
||||
if output_type == "np":
|
||||
return _get_3d_sincos_pos_embed_np(
|
||||
embed_dim=embed_dim,
|
||||
spatial_size=spatial_size,
|
||||
temporal_size=temporal_size,
|
||||
spatial_interpolation_scale=spatial_interpolation_scale,
|
||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
||||
)
|
||||
if embed_dim % 4 != 0:
|
||||
raise ValueError("`embed_dim` must be divisible by 4")
|
||||
if isinstance(spatial_size, int):
|
||||
spatial_size = (spatial_size, spatial_size)
|
||||
|
||||
embed_dim_spatial = 3 * embed_dim // 4
|
||||
embed_dim_temporal = embed_dim // 4
|
||||
|
||||
# 1. Spatial
|
||||
grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
|
||||
grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
|
||||
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
|
||||
grid = torch.stack(grid, dim=0)
|
||||
|
||||
grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
|
||||
pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
|
||||
|
||||
# 2. Temporal
|
||||
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
|
||||
pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
|
||||
|
||||
# 3. Concat
|
||||
pos_embed_spatial = pos_embed_spatial[None, :, :]
|
||||
pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3]
|
||||
|
||||
pos_embed_temporal = pos_embed_temporal[:, None, :]
|
||||
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
|
||||
spatial_size[0] * spatial_size[1], dim=1
|
||||
) # [T, H*W, D // 4]
|
||||
|
||||
pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D]
|
||||
return pos_embed
|
||||
|
||||
|
||||
def _get_3d_sincos_pos_embed_np(
|
||||
embed_dim: int,
|
||||
spatial_size: Union[int, Tuple[int, int]],
|
||||
temporal_size: int,
|
||||
spatial_interpolation_scale: float = 1.0,
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
) -> np.ndarray:
|
||||
r"""
|
||||
Creates 3D sinusoidal positional embeddings.
|
||||
@@ -106,6 +178,12 @@ def get_3d_sincos_pos_embed(
|
||||
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
|
||||
embed_dim]`.
|
||||
"""
|
||||
deprecation_message = (
|
||||
"`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
|
||||
" `from_numpy` is no longer required."
|
||||
" Pass `output_type='pt' to use the new version now."
|
||||
)
|
||||
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
|
||||
if embed_dim % 4 != 0:
|
||||
raise ValueError("`embed_dim` must be divisible by 4")
|
||||
if isinstance(spatial_size, int):
|
||||
@@ -139,6 +217,143 @@ def get_3d_sincos_pos_embed(
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(
|
||||
embed_dim,
|
||||
grid_size,
|
||||
cls_token=False,
|
||||
extra_tokens=0,
|
||||
interpolation_scale=1.0,
|
||||
base_size=16,
|
||||
device: Optional[torch.device] = None,
|
||||
output_type: str = "np",
|
||||
):
|
||||
"""
|
||||
Creates 2D sinusoidal positional embeddings.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`):
|
||||
The embedding dimension.
|
||||
grid_size (`int`):
|
||||
The size of the grid height and width.
|
||||
cls_token (`bool`, defaults to `False`):
|
||||
Whether or not to add a classification token.
|
||||
extra_tokens (`int`, defaults to `0`):
|
||||
The number of extra tokens to add.
|
||||
interpolation_scale (`float`, defaults to `1.0`):
|
||||
The scale of the interpolation.
|
||||
|
||||
Returns:
|
||||
pos_embed (`torch.Tensor`):
|
||||
Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
|
||||
embed_dim]` if using cls_token
|
||||
"""
|
||||
if output_type == "np":
|
||||
deprecation_message = (
|
||||
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
|
||||
" `from_numpy` is no longer required."
|
||||
" Pass `output_type='pt' to use the new version now."
|
||||
)
|
||||
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
|
||||
return get_2d_sincos_pos_embed_np(
|
||||
embed_dim=embed_dim,
|
||||
grid_size=grid_size,
|
||||
cls_token=cls_token,
|
||||
extra_tokens=extra_tokens,
|
||||
interpolation_scale=interpolation_scale,
|
||||
base_size=base_size,
|
||||
)
|
||||
if isinstance(grid_size, int):
|
||||
grid_size = (grid_size, grid_size)
|
||||
|
||||
grid_h = (
|
||||
torch.arange(grid_size[0], device=device, dtype=torch.float32)
|
||||
/ (grid_size[0] / base_size)
|
||||
/ interpolation_scale
|
||||
)
|
||||
grid_w = (
|
||||
torch.arange(grid_size[1], device=device, dtype=torch.float32)
|
||||
/ (grid_size[1] / base_size)
|
||||
/ interpolation_scale
|
||||
)
|
||||
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first
|
||||
grid = torch.stack(grid, dim=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
|
||||
r"""
|
||||
This function generates 2D sinusoidal positional embeddings from a grid.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): The embedding dimension.
|
||||
grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
|
||||
"""
|
||||
if output_type == "np":
|
||||
deprecation_message = (
|
||||
"`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
|
||||
" `from_numpy` is no longer required."
|
||||
" Pass `output_type='pt' to use the new version now."
|
||||
)
|
||||
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
|
||||
return get_2d_sincos_pos_embed_from_grid_np(
|
||||
embed_dim=embed_dim,
|
||||
grid=grid,
|
||||
)
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2)
|
||||
|
||||
emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
|
||||
"""
|
||||
This function generates 1D positional embeddings from a grid.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): The embedding dimension `D`
|
||||
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
|
||||
"""
|
||||
if output_type == "np":
|
||||
deprecation_message = (
|
||||
"`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
|
||||
" `from_numpy` is no longer required."
|
||||
" Pass `output_type='pt' to use the new version now."
|
||||
)
|
||||
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
|
||||
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = torch.outer(pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = torch.sin(out) # (M, D/2)
|
||||
emb_cos = torch.cos(out) # (M, D/2)
|
||||
|
||||
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_np(
|
||||
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
|
||||
):
|
||||
"""
|
||||
@@ -170,13 +385,13 @@ def get_2d_sincos_pos_embed(
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
|
||||
r"""
|
||||
This function generates 2D sinusoidal positional embeddings from a grid.
|
||||
|
||||
@@ -191,14 +406,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
|
||||
"""
|
||||
This function generates 1D positional embeddings from a grid.
|
||||
|
||||
@@ -288,10 +503,14 @@ class PatchEmbed(nn.Module):
|
||||
self.pos_embed = None
|
||||
elif pos_embed_type == "sincos":
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
||||
embed_dim,
|
||||
grid_size,
|
||||
base_size=self.base_size,
|
||||
interpolation_scale=self.interpolation_scale,
|
||||
output_type="pt",
|
||||
)
|
||||
persistent = True if pos_embed_max_size else False
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
|
||||
self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
|
||||
else:
|
||||
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
|
||||
|
||||
@@ -323,7 +542,6 @@ class PatchEmbed(nn.Module):
|
||||
height, width = latent.shape[-2:]
|
||||
else:
|
||||
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
||||
|
||||
latent = self.proj(latent)
|
||||
if self.flatten:
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
@@ -341,8 +559,10 @@ class PatchEmbed(nn.Module):
|
||||
grid_size=(height, width),
|
||||
base_size=self.base_size,
|
||||
interpolation_scale=self.interpolation_scale,
|
||||
device=latent.device,
|
||||
output_type="pt",
|
||||
)
|
||||
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
|
||||
pos_embed = pos_embed.float().unsqueeze(0)
|
||||
else:
|
||||
pos_embed = self.pos_embed
|
||||
|
||||
@@ -453,7 +673,9 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
|
||||
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
|
||||
|
||||
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
|
||||
def _get_positional_embeddings(
|
||||
self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
|
||||
) -> torch.Tensor:
|
||||
post_patch_height = sample_height // self.patch_size
|
||||
post_patch_width = sample_width // self.patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
|
||||
@@ -465,9 +687,11 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
post_time_compression_frames,
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=device,
|
||||
output_type="pt",
|
||||
)
|
||||
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
|
||||
joint_pos_embedding = torch.zeros(
|
||||
pos_embedding = pos_embedding.flatten(0, 1)
|
||||
joint_pos_embedding = pos_embedding.new_zeros(
|
||||
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
|
||||
)
|
||||
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
|
||||
@@ -521,8 +745,10 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
or self.sample_width != width
|
||||
or self.sample_frames != pre_time_compression_frames
|
||||
):
|
||||
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
|
||||
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
|
||||
pos_embedding = self._get_positional_embeddings(
|
||||
height, width, pre_time_compression_frames, device=embeds.device
|
||||
)
|
||||
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
|
||||
else:
|
||||
pos_embedding = self.pos_embedding
|
||||
|
||||
@@ -552,9 +778,11 @@ class CogView3PlusPatchEmbed(nn.Module):
|
||||
# Linear projection for text embeddings
|
||||
self.text_proj = nn.Linear(text_hidden_size, hidden_size)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
|
||||
)
|
||||
pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
|
||||
self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
@@ -594,6 +822,7 @@ def get_3d_rotary_pos_embed(
|
||||
use_real: bool = True,
|
||||
grid_type: str = "linspace",
|
||||
max_size: Optional[Tuple[int, int]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
@@ -621,16 +850,22 @@ def get_3d_rotary_pos_embed(
|
||||
if grid_type == "linspace":
|
||||
start, stop = crops_coords
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
||||
grid_t = np.arange(temporal_size, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
grid_h = torch.linspace(
|
||||
start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
|
||||
)
|
||||
grid_w = torch.linspace(
|
||||
start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
|
||||
)
|
||||
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
|
||||
grid_t = torch.linspace(
|
||||
0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
|
||||
)
|
||||
elif grid_type == "slice":
|
||||
max_h, max_w = max_size
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
grid_h = np.arange(max_h, dtype=np.float32)
|
||||
grid_w = np.arange(max_w, dtype=np.float32)
|
||||
grid_t = np.arange(temporal_size, dtype=np.float32)
|
||||
grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
|
||||
grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
|
||||
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
|
||||
else:
|
||||
raise ValueError("Invalid value passed for `grid_type`.")
|
||||
|
||||
@@ -640,10 +875,10 @@ def get_3d_rotary_pos_embed(
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
|
||||
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
|
||||
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
|
||||
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
|
||||
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
|
||||
|
||||
# BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
|
||||
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
|
||||
@@ -686,14 +921,21 @@ def get_3d_rotary_pos_embed_allegro(
|
||||
temporal_size,
|
||||
interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
|
||||
theta: int = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# TODO(aryan): docs
|
||||
start, stop = crops_coords
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
||||
grid_t = torch.linspace(
|
||||
0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
|
||||
)
|
||||
grid_h = torch.linspace(
|
||||
start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
|
||||
)
|
||||
grid_w = torch.linspace(
|
||||
start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 3
|
||||
@@ -715,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
|
||||
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
def get_2d_rotary_pos_embed(
|
||||
embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
|
||||
):
|
||||
"""
|
||||
RoPE for image tokens with 2d structure.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size
|
||||
crops_coords (`Tuple[int]`)
|
||||
The top-left and bottom-right coordinates of the crop.
|
||||
grid_size (`Tuple[int]`):
|
||||
The grid size of the positional embedding.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
device: (`torch.device`, **optional**):
|
||||
The device used to create tensors.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
|
||||
"""
|
||||
if output_type == "np":
|
||||
deprecation_message = (
|
||||
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
|
||||
" `from_numpy` is no longer required."
|
||||
" Pass `output_type='pt' to use the new version now."
|
||||
)
|
||||
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
|
||||
return _get_2d_rotary_pos_embed_np(
|
||||
embed_dim=embed_dim,
|
||||
crops_coords=crops_coords,
|
||||
grid_size=grid_size,
|
||||
use_real=use_real,
|
||||
)
|
||||
start, stop = crops_coords
|
||||
# scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
|
||||
grid_h = torch.linspace(
|
||||
start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
|
||||
)
|
||||
grid_w = torch.linspace(
|
||||
start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
|
||||
)
|
||||
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
|
||||
grid = torch.stack(grid, dim=0) # [2, W, H]
|
||||
|
||||
grid = grid.reshape([2, 1, *grid.shape[1:]])
|
||||
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
"""
|
||||
RoPE for image tokens with 2d structure.
|
||||
|
||||
@@ -959,7 +1251,12 @@ class FluxPosEmbed(nn.Module):
|
||||
freqs_dtype = torch.float32 if is_mps else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
@@ -1238,7 +1535,7 @@ class ImageProjection(nn.Module):
|
||||
batch_size = image_embeds.shape[0]
|
||||
|
||||
# image
|
||||
image_embeds = self.image_embeds(image_embeds)
|
||||
image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
|
||||
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
|
||||
image_embeds = self.norm(image_embeds)
|
||||
return image_embeds
|
||||
@@ -2099,6 +2396,187 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module):
|
||||
return out
|
||||
|
||||
|
||||
class IPAdapterTimeImageProjectionBlock(nn.Module):
|
||||
"""Block for IPAdapterTimeImageProjection.
|
||||
|
||||
Args:
|
||||
hidden_dim (`int`, defaults to 1280):
|
||||
The number of hidden channels.
|
||||
dim_head (`int`, defaults to 64):
|
||||
The number of head channels.
|
||||
heads (`int`, defaults to 20):
|
||||
Parallel attention heads.
|
||||
ffn_ratio (`int`, defaults to 4):
|
||||
The expansion ratio of feedforward network hidden layer channels.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int = 1280,
|
||||
dim_head: int = 64,
|
||||
heads: int = 20,
|
||||
ffn_ratio: int = 4,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
from .attention import FeedForward
|
||||
|
||||
self.ln0 = nn.LayerNorm(hidden_dim)
|
||||
self.ln1 = nn.LayerNorm(hidden_dim)
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_dim,
|
||||
cross_attention_dim=hidden_dim,
|
||||
dim_head=dim_head,
|
||||
heads=heads,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
)
|
||||
self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False)
|
||||
|
||||
# AdaLayerNorm
|
||||
self.adaln_silu = nn.SiLU()
|
||||
self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim)
|
||||
self.adaln_norm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
# Set attention scale and fuse KV
|
||||
self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head))
|
||||
self.attn.fuse_projections()
|
||||
self.attn.to_k = None
|
||||
self.attn.to_v = None
|
||||
|
||||
def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Image features.
|
||||
latents (`torch.Tensor`):
|
||||
Latent features.
|
||||
timestep_emb (`torch.Tensor`):
|
||||
Timestep embedding.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Output latent features.
|
||||
"""
|
||||
|
||||
# Shift and scale for AdaLayerNorm
|
||||
emb = self.adaln_proj(self.adaln_silu(timestep_emb))
|
||||
shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1)
|
||||
|
||||
# Fused Attention
|
||||
residual = latents
|
||||
x = self.ln0(x)
|
||||
latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
|
||||
batch_size = latents.shape[0]
|
||||
|
||||
query = self.attn.to_q(latents)
|
||||
kv_input = torch.cat((x, latents), dim=-2)
|
||||
key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // self.attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1)
|
||||
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
latents = weight @ value
|
||||
|
||||
latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim)
|
||||
latents = self.attn.to_out[0](latents)
|
||||
latents = self.attn.to_out[1](latents)
|
||||
latents = latents + residual
|
||||
|
||||
## FeedForward
|
||||
residual = latents
|
||||
latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
return self.ff(latents) + residual
|
||||
|
||||
|
||||
# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
class IPAdapterTimeImageProjection(nn.Module):
|
||||
"""Resampler of SD3 IP-Adapter with timestep embedding.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`, defaults to 1152):
|
||||
The feature dimension.
|
||||
output_dim (`int`, defaults to 2432):
|
||||
The number of output channels.
|
||||
hidden_dim (`int`, defaults to 1280):
|
||||
The number of hidden channels.
|
||||
depth (`int`, defaults to 4):
|
||||
The number of blocks.
|
||||
dim_head (`int`, defaults to 64):
|
||||
The number of head channels.
|
||||
heads (`int`, defaults to 20):
|
||||
Parallel attention heads.
|
||||
num_queries (`int`, defaults to 64):
|
||||
The number of queries.
|
||||
ffn_ratio (`int`, defaults to 4):
|
||||
The expansion ratio of feedforward network hidden layer channels.
|
||||
timestep_in_dim (`int`, defaults to 320):
|
||||
The number of input channels for timestep embedding.
|
||||
timestep_flip_sin_to_cos (`bool`, defaults to True):
|
||||
Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False).
|
||||
timestep_freq_shift (`int`, defaults to 0):
|
||||
Controls the timestep delta between frequencies between dimensions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int = 1152,
|
||||
output_dim: int = 2432,
|
||||
hidden_dim: int = 1280,
|
||||
depth: int = 4,
|
||||
dim_head: int = 64,
|
||||
heads: int = 20,
|
||||
num_queries: int = 64,
|
||||
ffn_ratio: int = 4,
|
||||
timestep_in_dim: int = 320,
|
||||
timestep_flip_sin_to_cos: bool = True,
|
||||
timestep_freq_shift: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5)
|
||||
self.proj_in = nn.Linear(embed_dim, hidden_dim)
|
||||
self.proj_out = nn.Linear(hidden_dim, output_dim)
|
||||
self.norm_out = nn.LayerNorm(output_dim)
|
||||
self.layers = nn.ModuleList(
|
||||
[IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)]
|
||||
)
|
||||
self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu")
|
||||
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Image features.
|
||||
timestep (`torch.Tensor`):
|
||||
Timestep in denoising process.
|
||||
Returns:
|
||||
`Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb).
|
||||
"""
|
||||
timestep_emb = self.time_proj(timestep).to(dtype=x.dtype)
|
||||
timestep_emb = self.time_embedding(timestep_emb)
|
||||
|
||||
latents = self.latents.repeat(x.size(0), 1, 1)
|
||||
|
||||
x = self.proj_in(x)
|
||||
x = x + timestep_emb[:, None]
|
||||
|
||||
for block in self.layers:
|
||||
latents = block(x, latents, timestep_emb)
|
||||
|
||||
latents = self.proj_out(latents)
|
||||
latents = self.norm_out(latents)
|
||||
|
||||
return latents, timestep_emb
|
||||
|
||||
|
||||
class MultiIPAdapterImageProjection(nn.Module):
|
||||
def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
|
||||
super().__init__()
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from array import array
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
@@ -25,8 +26,8 @@ import safetensors
|
||||
import torch
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
|
||||
from ..quantizers.quantization_config import QuantizationMethod
|
||||
from ..utils import (
|
||||
GGUF_FILE_EXTENSION,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFETENSORS_FILE_EXTENSION,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
@@ -34,6 +35,8 @@ from ..utils import (
|
||||
_get_model_file,
|
||||
deprecate,
|
||||
is_accelerate_available,
|
||||
is_gguf_available,
|
||||
is_torch_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
@@ -128,7 +131,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
|
||||
return old_class
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, dduf_reader=None):
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
||||
"""
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
@@ -138,15 +141,10 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
return checkpoint_file
|
||||
try:
|
||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
||||
if dduf_reader:
|
||||
checkpoint_file = dduf_reader.read_file(checkpoint_file)
|
||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||
if dduf_reader:
|
||||
# tensors are loaded on cpu
|
||||
return safetensors.torch.load(checkpoint_file)
|
||||
else:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
elif file_extension == GGUF_FILE_EXTENSION:
|
||||
return load_gguf_checkpoint(checkpoint_file)
|
||||
else:
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
|
||||
return torch.load(
|
||||
@@ -183,11 +181,12 @@ def load_model_dict_into_meta(
|
||||
hf_quantizer=None,
|
||||
keep_in_fp32_modules=None,
|
||||
) -> List[str]:
|
||||
if device is not None and not isinstance(device, (str, torch.device)):
|
||||
raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
|
||||
if hf_quantizer is None:
|
||||
device = device or torch.device("cpu")
|
||||
dtype = dtype or torch.float32
|
||||
is_quantized = hf_quantizer is not None
|
||||
is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
|
||||
|
||||
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
||||
empty_state_dict = model.state_dict()
|
||||
@@ -218,14 +217,15 @@ def load_model_dict_into_meta(
|
||||
set_module_kwargs["dtype"] = dtype
|
||||
|
||||
# bnb params are flattened.
|
||||
# gguf quants have a different shape based on the type of quantization applied
|
||||
if empty_state_dict[param_name].shape != param.shape:
|
||||
if (
|
||||
is_quant_method_bnb
|
||||
is_quantized
|
||||
and hf_quantizer.pre_quantized
|
||||
and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
|
||||
):
|
||||
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
|
||||
elif not is_quant_method_bnb:
|
||||
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
|
||||
else:
|
||||
model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
|
||||
raise ValueError(
|
||||
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
||||
@@ -279,7 +279,6 @@ def _fetch_index_file(
|
||||
revision,
|
||||
user_agent,
|
||||
commit_hash,
|
||||
dduf_reader=None,
|
||||
):
|
||||
if is_local:
|
||||
index_file = Path(
|
||||
@@ -305,7 +304,6 @@ def _fetch_index_file(
|
||||
subfolder=None,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_reader=dduf_reader,
|
||||
)
|
||||
index_file = Path(index_file)
|
||||
except (EntryNotFoundError, EnvironmentError):
|
||||
@@ -405,3 +403,78 @@ def _fetch_index_file_legacy(
|
||||
index_file = None
|
||||
|
||||
return index_file
|
||||
|
||||
|
||||
def _gguf_parse_value(_value, data_type):
|
||||
if not isinstance(data_type, list):
|
||||
data_type = [data_type]
|
||||
if len(data_type) == 1:
|
||||
data_type = data_type[0]
|
||||
array_data_type = None
|
||||
else:
|
||||
if data_type[0] != 9:
|
||||
raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
|
||||
data_type, array_data_type = data_type
|
||||
|
||||
if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
|
||||
_value = int(_value[0])
|
||||
elif data_type in [6, 12]:
|
||||
_value = float(_value[0])
|
||||
elif data_type in [7]:
|
||||
_value = bool(_value[0])
|
||||
elif data_type in [8]:
|
||||
_value = array("B", list(_value)).tobytes().decode()
|
||||
elif data_type in [9]:
|
||||
_value = _gguf_parse_value(_value, array_data_type)
|
||||
return _value
|
||||
|
||||
|
||||
def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
|
||||
"""
|
||||
Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
|
||||
attributes.
|
||||
|
||||
Args:
|
||||
gguf_checkpoint_path (`str`):
|
||||
The path the to GGUF file to load
|
||||
return_tensors (`bool`, defaults to `True`):
|
||||
Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
|
||||
metadata in memory.
|
||||
"""
|
||||
|
||||
if is_gguf_available() and is_torch_available():
|
||||
import gguf
|
||||
from gguf import GGUFReader
|
||||
|
||||
from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
|
||||
else:
|
||||
logger.error(
|
||||
"Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
|
||||
"https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
|
||||
)
|
||||
raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
|
||||
|
||||
reader = GGUFReader(gguf_checkpoint_path)
|
||||
|
||||
parsed_parameters = {}
|
||||
for tensor in reader.tensors:
|
||||
name = tensor.name
|
||||
quant_type = tensor.tensor_type
|
||||
|
||||
# if the tensor is a torch supported dtype do not use GGUFParameter
|
||||
is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
|
||||
if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
|
||||
_supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
|
||||
raise ValueError(
|
||||
(
|
||||
f"{name} has a quantization type: {str(quant_type)} which is unsupported."
|
||||
"\n\nCurrently the following quantization types are supported: \n\n"
|
||||
f"{_supported_quants_str}"
|
||||
"\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
|
||||
)
|
||||
)
|
||||
|
||||
weights = torch.from_numpy(tensor.data.copy())
|
||||
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
|
||||
|
||||
return parsed_parameters
|
||||
|
||||
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
|
||||
|
||||
|
||||
def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
|
||||
try:
|
||||
return next(parameter.parameters()).dtype
|
||||
except StopIteration:
|
||||
try:
|
||||
return next(parameter.buffers()).dtype
|
||||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
"""
|
||||
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
|
||||
"""
|
||||
last_dtype = None
|
||||
for param in parameter.parameters():
|
||||
last_dtype = param.dtype
|
||||
if param.is_floating_point():
|
||||
return param.dtype
|
||||
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
for buffer in parameter.buffers():
|
||||
last_dtype = buffer.dtype
|
||||
if buffer.is_floating_point():
|
||||
return buffer.dtype
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].dtype
|
||||
if last_dtype is not None:
|
||||
# if no floating dtype was found return whatever the first dtype is
|
||||
return last_dtype
|
||||
|
||||
# For nn.DataParallel compatibility in PyTorch > 1.5
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
last_tuple = None
|
||||
for tuple in gen:
|
||||
last_tuple = tuple
|
||||
if tuple[1].is_floating_point():
|
||||
return tuple[1].dtype
|
||||
|
||||
if last_tuple is not None:
|
||||
# fallback to the last dtype
|
||||
return last_tuple[1].dtype
|
||||
|
||||
|
||||
class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
@@ -208,6 +226,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"""
|
||||
self.set_use_npu_flash_attention(False)
|
||||
|
||||
def set_use_xla_flash_attention(
|
||||
self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
|
||||
) -> None:
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_use_xla_flash_attention method
|
||||
# gets the message
|
||||
def fn_recursive_set_flash_attention(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_xla_flash_attention"):
|
||||
module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_flash_attention(child)
|
||||
|
||||
for module in self.children():
|
||||
if isinstance(module, torch.nn.Module):
|
||||
fn_recursive_set_flash_attention(module)
|
||||
|
||||
def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
|
||||
r"""
|
||||
Enable the flash attention pallals kernel for torch_xla.
|
||||
"""
|
||||
self.set_use_xla_flash_attention(True, partition_spec)
|
||||
|
||||
def disable_xla_flash_attention(self):
|
||||
r"""
|
||||
Disable the flash attention pallals kernel for torch_xla.
|
||||
"""
|
||||
self.set_use_xla_flash_attention(False)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
self, valid: bool, attention_op: Optional[Callable] = None
|
||||
) -> None:
|
||||
@@ -557,7 +604,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
dduf_reader = kwargs.pop("dduf_reader", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
@@ -650,7 +696,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
dduf_reader=dduf_reader,
|
||||
**kwargs,
|
||||
)
|
||||
# no in-place modification of the original config.
|
||||
@@ -673,10 +718,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
hf_quantizer = None
|
||||
|
||||
if hf_quantizer is not None:
|
||||
if device_map is not None:
|
||||
is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
|
||||
if is_bnb_quantization_method and device_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
|
||||
"Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
|
||||
)
|
||||
|
||||
hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
|
||||
@@ -726,7 +773,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"revision": revision,
|
||||
"user_agent": user_agent,
|
||||
"commit_hash": commit_hash,
|
||||
"dduf_reader": dduf_reader,
|
||||
}
|
||||
index_file = _fetch_index_file(**index_file_kwargs)
|
||||
# In case the index file was not found we still have to consider the legacy format.
|
||||
@@ -762,8 +808,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
||||
else:
|
||||
# in the case it is sharded, we have already the index
|
||||
if is_sharded and not dduf_reader:
|
||||
if is_sharded:
|
||||
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
index_file,
|
||||
@@ -775,7 +820,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
revision=revision,
|
||||
subfolder=subfolder or "",
|
||||
)
|
||||
if hf_quantizer is not None:
|
||||
if hf_quantizer is not None and is_bnb_quantization_method:
|
||||
model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
|
||||
logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
|
||||
is_sharded = False
|
||||
@@ -794,7 +839,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_reader=dduf_reader,
|
||||
)
|
||||
|
||||
except IOError as e:
|
||||
@@ -818,7 +862,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_reader=dduf_reader,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
@@ -835,15 +878,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
if device_map is None and not is_sharded:
|
||||
# `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
|
||||
# It would error out during the `validate_environment()` call above in the absence of cuda.
|
||||
is_quant_method_bnb = (
|
||||
getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
|
||||
)
|
||||
if hf_quantizer is None:
|
||||
param_device = "cpu"
|
||||
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
|
||||
elif is_quant_method_bnb:
|
||||
param_device = torch.cuda.current_device()
|
||||
state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
|
||||
else:
|
||||
param_device = torch.device(torch.cuda.current_device())
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
# move the params from meta device to cpu
|
||||
@@ -943,7 +983,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
@@ -1016,14 +1056,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dtype_present_in_args = True
|
||||
break
|
||||
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_quantized", False):
|
||||
if dtype_present_in_args:
|
||||
raise ValueError(
|
||||
"You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
|
||||
" desired `dtype` by passing the correct `torch_dtype` argument."
|
||||
"Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
|
||||
"use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
|
||||
)
|
||||
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
raise ValueError(
|
||||
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
|
||||
|
||||
@@ -234,33 +234,6 @@ class LuminaRMSNormZero(nn.Module):
|
||||
return x, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class MochiRMSNormZero(nn.Module):
|
||||
r"""
|
||||
Adaptive RMS Norm used in Mochi.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, hidden_dim)
|
||||
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, emb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
|
||||
|
||||
return hidden_states, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
@@ -512,6 +485,46 @@ else:
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
|
||||
if isinstance(dim, numbers.Integral):
|
||||
dim = (dim,)
|
||||
|
||||
self.dim = torch.Size(dim)
|
||||
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.zeros(dim))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
|
||||
if self.weight is not None:
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = hidden_states * self.weight
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
else:
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
|
||||
# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
|
||||
class MochiRMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
|
||||
super().__init__()
|
||||
|
||||
@@ -533,12 +546,8 @@ class RMSNorm(nn.Module):
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
|
||||
if self.weight is not None:
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = hidden_states * self.weight
|
||||
else:
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -566,3 +575,21 @@ class LpNorm(nn.Module):
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
|
||||
|
||||
|
||||
def get_normalization(
|
||||
norm_type: str = "batch_norm",
|
||||
num_features: Optional[int] = None,
|
||||
eps: float = 1e-5,
|
||||
elementwise_affine: bool = True,
|
||||
bias: bool = True,
|
||||
) -> nn.Module:
|
||||
if norm_type == "rms_norm":
|
||||
norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
|
||||
elif norm_type == "layer_norm":
|
||||
norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
|
||||
elif norm_type == "batch_norm":
|
||||
norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"{norm_type=} is not supported.")
|
||||
return norm
|
||||
|
||||
@@ -11,12 +11,15 @@ if is_torch_available():
|
||||
from .lumina_nextdit2d import LuminaNextDiT2DModel
|
||||
from .pixart_transformer_2d import PixArtTransformer2DModel
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .sana_transformer import SanaTransformer2DModel
|
||||
from .stable_audio_transformer import StableAudioDiTModel
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_allegro import AllegroTransformer3DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
@@ -156,9 +156,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# define temporal positional embedding
|
||||
temp_pos_embed = get_1d_sincos_pos_embed_from_grid(
|
||||
inner_dim, torch.arange(0, video_length).unsqueeze(1)
|
||||
inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt"
|
||||
) # 1152 hidden size
|
||||
self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False)
|
||||
self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
|
||||
@@ -0,0 +1,487 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnProcessor2_0,
|
||||
SanaLinearAttnProcessor2_0,
|
||||
)
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class GLUMBConv(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
expand_ratio: float = 4,
|
||||
norm_type: Optional[str] = None,
|
||||
residual_connection: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_channels = int(expand_ratio * in_channels)
|
||||
self.norm_type = norm_type
|
||||
self.residual_connection = residual_connection
|
||||
|
||||
self.nonlinearity = nn.SiLU()
|
||||
self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
|
||||
self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
|
||||
self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
|
||||
|
||||
self.norm = None
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
if self.residual_connection:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.conv_inverted(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv_depth(hidden_states)
|
||||
hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
|
||||
hidden_states = hidden_states * self.nonlinearity(gate)
|
||||
|
||||
hidden_states = self.conv_point(hidden_states)
|
||||
|
||||
if self.norm_type == "rms_norm":
|
||||
# move channel to the last dimension so we apply RMSnorm across channel dimension
|
||||
hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
|
||||
|
||||
if self.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2240,
|
||||
num_attention_heads: int = 70,
|
||||
attention_head_dim: int = 32,
|
||||
dropout: float = 0.0,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
attention_bias: bool = True,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
attention_out_bias: bool = True,
|
||||
mlp_ratio: float = 2.5,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 1. Self Attention
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
processor=SanaLinearAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 2. Cross Attention
|
||||
if cross_attention_dim is not None:
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_cross_attention_heads,
|
||||
dim_head=cross_attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=True,
|
||||
out_bias=attention_out_bias,
|
||||
processor=AttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
height: int = None,
|
||||
width: int = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# 1. Modulation
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 2. Self Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
|
||||
|
||||
attn_output = self.attn1(norm_hidden_states)
|
||||
hidden_states = hidden_states + gate_msa * attn_output
|
||||
|
||||
# 3. Cross Attention
|
||||
if self.attn2 is not None:
|
||||
attn_output = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 4. Feed-forward
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = ff_output.flatten(2, 3).permute(0, 2, 1)
|
||||
hidden_states = hidden_states + gate_mlp * ff_output
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `32`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `32`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `70`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `32`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
num_cross_attention_heads (`int`, *optional*, defaults to `20`):
|
||||
The number of heads to use for cross-attention.
|
||||
cross_attention_head_dim (`int`, *optional*, defaults to `112`):
|
||||
The number of channels in each head for cross-attention.
|
||||
cross_attention_dim (`int`, *optional*, defaults to `2240`):
|
||||
The number of channels in the cross-attention output.
|
||||
caption_channels (`int`, defaults to `2304`):
|
||||
The number of channels in the caption embeddings.
|
||||
mlp_ratio (`float`, defaults to `2.5`):
|
||||
The expansion ratio to use in the GLUMBConv layer.
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability.
|
||||
attention_bias (`bool`, defaults to `False`):
|
||||
Whether to use bias in the attention layer.
|
||||
sample_size (`int`, defaults to `32`):
|
||||
The base size of the input latent.
|
||||
patch_size (`int`, defaults to `1`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
norm_elementwise_affine (`bool`, defaults to `False`):
|
||||
Whether to use elementwise affinity in the normalization layer.
|
||||
norm_eps (`float`, defaults to `1e-6`):
|
||||
The epsilon value for the normalization layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 32,
|
||||
out_channels: Optional[int] = 32,
|
||||
num_attention_heads: int = 70,
|
||||
attention_head_dim: int = 32,
|
||||
num_layers: int = 20,
|
||||
num_cross_attention_heads: Optional[int] = 20,
|
||||
cross_attention_head_dim: Optional[int] = 112,
|
||||
cross_attention_dim: Optional[int] = 2240,
|
||||
caption_channels: int = 2304,
|
||||
mlp_ratio: float = 2.5,
|
||||
dropout: float = 0.0,
|
||||
attention_bias: bool = False,
|
||||
sample_size: int = 32,
|
||||
patch_size: int = 1,
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Patch Embedding
|
||||
self.patch_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=None,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
|
||||
# 2. Additional condition embeddings
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
SanaTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
num_cross_attention_heads=num_cross_attention_heads,
|
||||
cross_attention_head_dim=cross_attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
mlp_ratio=mlp_ratio,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output blocks
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. Input
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
p = self.config.patch_size
|
||||
post_patch_height, post_patch_width = height // p, width // p
|
||||
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
|
||||
timestep, embedded_timestep = self.time_embed(
|
||||
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
encoder_hidden_states = self.caption_norm(encoder_hidden_states)
|
||||
|
||||
# 2. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
)
|
||||
|
||||
# 3. Normalization
|
||||
shift, scale = (
|
||||
self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
|
||||
).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
# 4. Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. Unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
|
||||
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
@@ -177,13 +177,18 @@ class FluxTransformerBlock(nn.Module):
|
||||
)
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
attention_outputs = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**joint_attention_kwargs,
|
||||
)
|
||||
|
||||
if len(attention_outputs) == 2:
|
||||
attn_output, context_attn_output = attention_outputs
|
||||
elif len(attention_outputs) == 3:
|
||||
attn_output, context_attn_output, ip_attn_output = attention_outputs
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
@@ -195,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
if len(attention_outputs) == 3:
|
||||
hidden_states = hidden_states + ip_attn_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
|
||||
@@ -212,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class FluxTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Flux.
|
||||
|
||||
@@ -482,6 +491,11 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
|
||||
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -524,7 +538,6 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
|
||||
@@ -0,0 +1,787 @@
|
||||
# Copyright 2024 The Hunyuan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, AttentionProcessor
|
||||
from ..embeddings import (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings,
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HunyuanVideoAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
if attn.add_q_proj is None and encoder_hidden_states is not None:
|
||||
query = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
||||
query[:, :, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
key = torch.cat(
|
||||
[
|
||||
apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
|
||||
key[:, :, -encoder_hidden_states.shape[1] :],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
else:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
# 4. Encoder condition QKV projection and normalization
|
||||
if attn.add_q_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=2)
|
||||
key = torch.cat([key, encoder_key], dim=2)
|
||||
value = torch.cat([value, encoder_value], dim=2)
|
||||
|
||||
# 5. Attention
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 6. Output projection
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : -encoder_hidden_states.shape[1]],
|
||||
hidden_states[:, -encoder_hidden_states.shape[1] :],
|
||||
)
|
||||
|
||||
if getattr(attn, "to_out", None) is not None:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if getattr(attn, "to_add_out", None) is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: Union[int, Tuple[int, int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
|
||||
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoAdaNorm(nn.Module):
|
||||
def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_features = out_features or 2 * in_features
|
||||
self.linear = nn.Linear(in_features, out_features)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
def forward(
|
||||
self, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
temb = self.linear(self.nonlinearity(temb))
|
||||
gate_msa, gate_mlp = temb.chunk(2, dim=1)
|
||||
gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
|
||||
return gate_msa, gate_mlp
|
||||
|
||||
|
||||
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_width_ratio: str = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=attention_bias,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
|
||||
|
||||
self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
gate_msa, gate_mlp = self.norm_out(temb)
|
||||
hidden_states = hidden_states + attn_output * gate_msa
|
||||
|
||||
ff_output = self.ff(self.norm2(hidden_states))
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoIndividualTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_width_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.refiner_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoIndividualTokenRefinerBlock(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
mlp_width_ratio=mlp_width_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
self_attn_mask = None
|
||||
if attention_mask is not None:
|
||||
batch_size = attention_mask.shape[0]
|
||||
seq_len = attention_mask.shape[1]
|
||||
attention_mask = attention_mask.to(hidden_states.device).bool()
|
||||
self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
|
||||
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
|
||||
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
|
||||
self_attn_mask[:, :, :, 0] = True
|
||||
|
||||
for block in self.refiner_blocks:
|
||||
hidden_states = block(hidden_states, temb, self_attn_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_layers: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
mlp_drop_rate: float = 0.0,
|
||||
attention_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=hidden_size, pooled_projection_dim=in_channels
|
||||
)
|
||||
self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.token_refiner = HunyuanVideoIndividualTokenRefiner(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_layers=num_layers,
|
||||
mlp_width_ratio=mlp_ratio,
|
||||
mlp_drop_rate=mlp_drop_rate,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is None:
|
||||
pooled_projections = hidden_states.mean(dim=1)
|
||||
else:
|
||||
original_dtype = hidden_states.dtype
|
||||
mask_float = attention_mask.float().unsqueeze(-1)
|
||||
pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
pooled_projections = pooled_projections.to(original_dtype)
|
||||
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.rope_dim = rope_dim
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
|
||||
|
||||
axes_grids = []
|
||||
for i in range(3):
|
||||
# Note: The following line diverges from original behaviour. We create the grid on the device, whereas
|
||||
# original implementation creates it on CPU and then moves it to device. This results in numerical
|
||||
# differences in layerwise debugging outputs, but visually it is the same.
|
||||
grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
|
||||
axes_grids.append(grid)
|
||||
grid = torch.meshgrid(*axes_grids, indexing="ij") # [W, H, T]
|
||||
grid = torch.stack(grid, dim=0) # [3, W, H, T]
|
||||
|
||||
freqs = []
|
||||
for i in range(3):
|
||||
freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
|
||||
freqs.append(freq)
|
||||
|
||||
freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class HunyuanVideoSingleTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
mlp_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
pre_only=True,
|
||||
)
|
||||
|
||||
self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
|
||||
self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
|
||||
self.act_mlp = nn.GELU(approximate="tanh")
|
||||
self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_length = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
|
||||
norm_hidden_states, norm_encoder_hidden_states = (
|
||||
norm_hidden_states[:, :-text_seq_length, :],
|
||||
norm_hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
|
||||
# 2. Attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
attn_output = torch.cat([attn_output, context_attn_output], dim=1)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
|
||||
hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, :-text_seq_length, :],
|
||||
hidden_states[:, -text_seq_length:, :],
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
mlp_ratio: float,
|
||||
qk_norm: str = "rms_norm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hidden_size = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
|
||||
|
||||
self.attn = Attention(
|
||||
query_dim=hidden_size,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=hidden_size,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=hidden_size,
|
||||
context_pre_only=False,
|
||||
bias=True,
|
||||
processor=HunyuanVideoAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=1e-6,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Input normalization
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# 2. Joint attention
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=freqs_cis,
|
||||
)
|
||||
|
||||
# 3. Modulation and residual connection
|
||||
hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
num_attention_heads (`int`, defaults to `24`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
num_layers (`int`, defaults to `20`):
|
||||
The number of layers of dual-stream blocks to use.
|
||||
num_single_layers (`int`, defaults to `40`):
|
||||
The number of layers of single-stream blocks to use.
|
||||
num_refiner_layers (`int`, defaults to `2`):
|
||||
The number of layers of refiner blocks to use.
|
||||
mlp_ratio (`float`, defaults to `4.0`):
|
||||
The ratio of the hidden layer size to the input size in the feedforward network.
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the spatial patches to use in the patch embedding layer.
|
||||
patch_size_t (`int`, defaults to `1`):
|
||||
The size of the tmeporal patches to use in the patch embedding layer.
|
||||
qk_norm (`str`, defaults to `rms_norm`):
|
||||
The normalization to use for the query and key projections in the attention layers.
|
||||
guidance_embeds (`bool`, defaults to `True`):
|
||||
Whether to use guidance embeddings in the model.
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
pooled_projection_dim (`int`, defaults to `768`):
|
||||
The dimension of the pooled projection of the text embeddings.
|
||||
rope_theta (`float`, defaults to `256.0`):
|
||||
The value of theta to use in the RoPE layer.
|
||||
rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions of the axes to use in the RoPE layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
num_attention_heads: int = 24,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 20,
|
||||
num_single_layers: int = 40,
|
||||
num_refiner_layers: int = 2,
|
||||
mlp_ratio: float = 4.0,
|
||||
patch_size: int = 2,
|
||||
patch_size_t: int = 1,
|
||||
qk_norm: str = "rms_norm",
|
||||
guidance_embeds: bool = True,
|
||||
text_embed_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
rope_theta: float = 256.0,
|
||||
rope_axes_dim: Tuple[int] = (16, 56, 56),
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Latent and condition embedders
|
||||
self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
|
||||
self.context_embedder = HunyuanVideoTokenRefiner(
|
||||
text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
|
||||
)
|
||||
self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)
|
||||
|
||||
# 2. RoPE
|
||||
self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
|
||||
|
||||
# 3. Dual stream transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Single stream transformer blocks
|
||||
self.single_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
HunyuanVideoSingleTransformerBlock(
|
||||
num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
|
||||
)
|
||||
for _ in range(num_single_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
pooled_projections: torch.Tensor,
|
||||
guidance: torch.Tensor = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p, p_t = self.config.patch_size, self.config.patch_size_t
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
# 1. RoPE
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Conditional embeddings
|
||||
temb = self.time_text_embed(timestep, guidance, pooled_projections)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
|
||||
|
||||
# 3. Attention mask preparation
|
||||
latent_sequence_length = hidden_states.shape[1]
|
||||
condition_sequence_length = encoder_hidden_states.shape[1]
|
||||
sequence_length = latent_sequence_length + condition_sequence_length
|
||||
attention_mask = torch.zeros(
|
||||
batch_size, sequence_length, sequence_length, device=hidden_states.device, dtype=torch.bool
|
||||
) # [B, N, N]
|
||||
|
||||
effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
|
||||
effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
|
||||
|
||||
for i in range(batch_size):
|
||||
attention_mask[i, : effective_sequence_length[i], : effective_sequence_length[i]] = True
|
||||
|
||||
# 4. Transformer blocks
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
)
|
||||
|
||||
for block in self.single_transformer_blocks:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states, encoder_hidden_states, temb, attention_mask, image_rotary_emb
|
||||
)
|
||||
|
||||
# 5. Output projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
@@ -0,0 +1,469 @@
|
||||
# Copyright 2024 The Genmo team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class LTXAttentionProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"LTXAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTXRotaryPosEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base_num_frames: int = 20,
|
||||
base_height: int = 2048,
|
||||
base_width: int = 2048,
|
||||
patch_size: int = 1,
|
||||
patch_size_t: int = 1,
|
||||
theta: float = 10000.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.base_num_frames = base_num_frames
|
||||
self.base_height = base_height
|
||||
self.base_width = base_width
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.theta = theta
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# Always compute rope in fp32
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=hidden_states.device)
|
||||
grid_f = torch.arange(num_frames, dtype=torch.float32, device=hidden_states.device)
|
||||
grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing="ij")
|
||||
grid = torch.stack(grid, dim=0)
|
||||
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
if rope_interpolation_scale is not None:
|
||||
grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
|
||||
grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
|
||||
grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
|
||||
|
||||
grid = grid.flatten(2, 4).transpose(1, 2)
|
||||
|
||||
start = 1.0
|
||||
end = self.theta
|
||||
freqs = self.theta ** torch.linspace(
|
||||
math.log(start, self.theta),
|
||||
math.log(end, self.theta),
|
||||
self.dim // 6,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
freqs = freqs * math.pi / 2.0
|
||||
freqs = freqs * (grid.unsqueeze(-1) * 2 - 1)
|
||||
freqs = freqs.transpose(-1, -2).flatten(2)
|
||||
|
||||
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
if self.dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freqs[:, :, : self.dim % 6])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
return cos_freqs, sin_freqs
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class LTXTransformerBlock(nn.Module):
|
||||
r"""
|
||||
Transformer block used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
|
||||
|
||||
Args:
|
||||
dim (`int`):
|
||||
The number of channels in the input and output.
|
||||
num_attention_heads (`int`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`):
|
||||
The number of channels in each head.
|
||||
qk_norm (`str`, defaults to `"rms_norm"`):
|
||||
The normalization layer to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
cross_attention_dim: int,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
eps: float = 1e-6,
|
||||
elementwise_affine: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=attention_bias,
|
||||
cross_attention_dim=None,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.attn2 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=attention_bias,
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
processor=LTXAttentionProcessor2_0(),
|
||||
)
|
||||
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = hidden_states.size(0)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
image_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `128`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `128`):
|
||||
The number of channels in the output.
|
||||
patch_size (`int`, defaults to `1`):
|
||||
The size of the spatial patches to use in the patch embedding layer.
|
||||
patch_size_t (`int`, defaults to `1`):
|
||||
The size of the tmeporal patches to use in the patch embedding layer.
|
||||
num_attention_heads (`int`, defaults to `32`):
|
||||
The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, defaults to `64`):
|
||||
The number of channels in each head.
|
||||
cross_attention_dim (`int`, defaults to `2048 `):
|
||||
The number of channels for cross attention heads.
|
||||
num_layers (`int`, defaults to `28`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
activation_fn (`str`, defaults to `"gelu-approximate"`):
|
||||
Activation function to use in feed-forward.
|
||||
qk_norm (`str`, defaults to `"rms_norm_across_heads"`):
|
||||
The normalization layer to use.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 128,
|
||||
patch_size: int = 1,
|
||||
patch_size_t: int = 1,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
cross_attention_dim: int = 2048,
|
||||
num_layers: int = 28,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
caption_channels: int = 4096,
|
||||
attention_bias: bool = True,
|
||||
attention_out_bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.time_embed = AdaLayerNormSingle(inner_dim, use_additional_conditions=False)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.rope = LTXRotaryPosEmbed(
|
||||
dim=inner_dim,
|
||||
base_num_frames=20,
|
||||
base_height=2048,
|
||||
base_width=2048,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
theta=10000.0,
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
LTXTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
qk_norm=qk_norm,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
attention_out_bias=attention_out_bias,
|
||||
eps=norm_eps,
|
||||
elementwise_affine=norm_elementwise_affine,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim, eps=1e-6, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
temb, embedded_timestep = self.time_embed(
|
||||
timestep.flatten(),
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
temb = temb.view(batch_size, -1, temb.size(-1))
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))
|
||||
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
encoder_attention_mask,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs):
|
||||
cos, sin = freqs
|
||||
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
@@ -20,19 +20,100 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, MochiAttnProcessor2_0
|
||||
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
|
||||
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
|
||||
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class MochiModulatedRMSNorm(nn.Module):
|
||||
def __init__(self, eps: float):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(0, eps, False)
|
||||
|
||||
def forward(self, hidden_states, scale=None):
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
if scale is not None:
|
||||
hidden_states = hidden_states * scale
|
||||
|
||||
hidden_states = hidden_states.to(hidden_states_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MochiLayerNormContinuous(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
eps=1e-5,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# AdaLN
|
||||
self.silu = nn.SiLU()
|
||||
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
||||
self.norm = MochiModulatedRMSNorm(eps=eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
conditioning_embedding: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
input_dtype = x.dtype
|
||||
|
||||
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
||||
scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
||||
x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
|
||||
|
||||
return x.to(input_dtype)
|
||||
|
||||
|
||||
class MochiRMSNormZero(nn.Module):
|
||||
r"""
|
||||
Adaptive RMS Norm used in Mochi.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, hidden_dim)
|
||||
self.norm = RMSNorm(0, eps, False)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, emb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32))
|
||||
hidden_states = hidden_states.to(hidden_states_dtype)
|
||||
|
||||
return hidden_states, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformerBlock(nn.Module):
|
||||
r"""
|
||||
@@ -77,38 +158,32 @@ class MochiTransformerBlock(nn.Module):
|
||||
if not context_pre_only:
|
||||
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
else:
|
||||
self.norm1_context = LuminaLayerNormContinuous(
|
||||
self.norm1_context = MochiLayerNormContinuous(
|
||||
embedding_dim=pooled_projection_dim,
|
||||
conditioning_embedding_dim=dim,
|
||||
eps=eps,
|
||||
elementwise_affine=False,
|
||||
norm_type="rms_norm",
|
||||
out_dim=None,
|
||||
)
|
||||
|
||||
self.attn1 = Attention(
|
||||
self.attn1 = MochiAttention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=False,
|
||||
qk_norm=qk_norm,
|
||||
added_kv_proj_dim=pooled_projection_dim,
|
||||
added_proj_bias=False,
|
||||
out_dim=dim,
|
||||
out_context_dim=pooled_projection_dim,
|
||||
context_pre_only=context_pre_only,
|
||||
processor=MochiAttnProcessor2_0(),
|
||||
eps=eps,
|
||||
elementwise_affine=True,
|
||||
eps=1e-5,
|
||||
)
|
||||
|
||||
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
self.norm2 = MochiModulatedRMSNorm(eps=eps)
|
||||
self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
|
||||
|
||||
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
self.norm3 = MochiModulatedRMSNorm(eps)
|
||||
self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
|
||||
|
||||
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
|
||||
self.ff_context = None
|
||||
@@ -120,14 +195,15 @@ class MochiTransformerBlock(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
self.norm4 = MochiModulatedRMSNorm(eps=eps)
|
||||
self.norm4_context = MochiModulatedRMSNorm(eps=eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
@@ -143,22 +219,25 @@ class MochiTransformerBlock(nn.Module):
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
|
||||
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
|
||||
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
|
||||
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
|
||||
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
|
||||
|
||||
if not self.context_pre_only:
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
||||
context_attn_hidden_states
|
||||
) * torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
|
||||
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||
)
|
||||
norm_encoder_hidden_states = self.norm3_context(
|
||||
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
|
||||
)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
|
||||
enc_gate_mlp
|
||||
).unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
|
||||
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
|
||||
)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
@@ -203,7 +282,10 @@ class MochiRoPE(nn.Module):
|
||||
return positions
|
||||
|
||||
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
|
||||
with torch.autocast(freqs.device.type, torch.float32):
|
||||
# Always run ROPE freqs computation in FP32
|
||||
freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
|
||||
|
||||
freqs_cos = torch.cos(freqs)
|
||||
freqs_sin = torch.sin(freqs)
|
||||
return freqs_cos, freqs_sin
|
||||
@@ -223,7 +305,7 @@ class MochiRoPE(nn.Module):
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
|
||||
|
||||
@@ -253,6 +335,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["MochiTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -309,7 +392,11 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(
|
||||
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
|
||||
inner_dim,
|
||||
inner_dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
norm_type="layer_norm",
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
@@ -350,7 +437,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
post_patch_width = width // p
|
||||
|
||||
temb, encoder_hidden_states = self.time_embed(
|
||||
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
@@ -381,6 +471,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
encoder_attention_mask,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
@@ -389,9 +480,9 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
|
||||
@@ -11,17 +11,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
||||
from ...models.attention import FeedForward, JointTransformerBlock
|
||||
from ...models.attention_processor import (
|
||||
Attention,
|
||||
@@ -106,7 +103,9 @@ class SD3SingleTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
class SD3Transformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Stable Diffusion 3.
|
||||
|
||||
@@ -352,8 +351,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`):
|
||||
Embeddings projected from the embeddings of input conditions.
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states (`list` of `torch.Tensor`):
|
||||
@@ -393,6 +392,12 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep)
|
||||
|
||||
joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
# Skip specified layers
|
||||
is_skip = True if skip_layers is not None and index_block in skip_layers else False
|
||||
@@ -414,18 +419,21 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
joint_attention_kwargs,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
elif not is_skip:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
# controlnet residual
|
||||
if block_controlnet_hidden_states is not None and block.context_pre_only is False:
|
||||
interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
|
||||
hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
@@ -217,7 +217,7 @@ class MidResTemporalBlock1D(nn.Module):
|
||||
if self.upsample:
|
||||
hidden_states = self.upsample(hidden_states)
|
||||
if self.downsample:
|
||||
self.downsample = self.downsample(hidden_states)
|
||||
hidden_states = self.downsample(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user