Compare commits
43 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a43b260c7a | |||
| a79c3af6bb | |||
| 3f3f0c16a6 | |||
| f3e1310469 | |||
| 87f83d3dd9 | |||
| f064b3bf73 | |||
| 3b079ec3fa | |||
| bc34fa8386 | |||
| 05e7a854d0 | |||
| 76ec3d1fee | |||
| cdaf84a708 | |||
| e8e44a510c | |||
| 21543de571 | |||
| d7dd924ece | |||
| 00f95b9755 | |||
| eea76892e8 | |||
| 27bf7fcd0e | |||
| a185e1ab91 | |||
| d93381cd41 | |||
| 3649d7b903 | |||
| 10c36e0b78 | |||
| 8846635873 | |||
| dd285099eb | |||
| 80f27d7e8d | |||
| d3e27e05f0 | |||
| 5df02fc171 | |||
| 7392c8ff5a | |||
| 474a248f10 | |||
| 7bc0a07b19 | |||
| 92542719ed | |||
| 6760300202 | |||
| 798265f2b6 | |||
| cd813499be | |||
| fbddf02807 | |||
| f20b83a04f | |||
| ee40088fe5 | |||
| 7fc53b5d66 | |||
| 0874dd04dc | |||
| 6184d8a433 | |||
| 5a6e386464 | |||
| 42077e6c73 | |||
| 3d8d8485fc | |||
| 195926bbdc |
@@ -75,10 +75,6 @@ jobs:
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-pytorch-xformers-cuda
|
||||
- diffusers-pytorch-minimum-cuda
|
||||
- diffusers-flax-cpu
|
||||
- diffusers-flax-tpu
|
||||
- diffusers-onnxruntime-cpu
|
||||
- diffusers-onnxruntime-cuda
|
||||
- diffusers-doc-builder
|
||||
|
||||
steps:
|
||||
|
||||
@@ -321,55 +321,6 @@ jobs:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
run_nightly_onnx_tests:
|
||||
name: Nightly ONNXRuntime CUDA tests on Ubuntu
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
|
||||
- name: Run Nightly ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
--report-log=tests_onnx_cuda.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: tests_onnx_cuda_reports
|
||||
path: reports
|
||||
|
||||
run_nightly_quantization_tests:
|
||||
name: Torch quantization nightly tests
|
||||
strategy:
|
||||
@@ -485,57 +436,6 @@ jobs:
|
||||
name: torch_cuda_pipeline_level_quant_reports
|
||||
path: reports
|
||||
|
||||
run_flax_tpu_tests:
|
||||
name: Nightly Flax TPU Tests
|
||||
runs-on:
|
||||
group: gcp-ct5lp-hightpu-8t
|
||||
if: github.event_name == 'schedule'
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
|
||||
- name: Run nightly Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
--report-log=tests_flax_tpu.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
generate_consolidated_report:
|
||||
name: Generate Consolidated Test Report
|
||||
needs: [
|
||||
@@ -545,9 +445,9 @@ jobs:
|
||||
run_big_gpu_torch_tests,
|
||||
run_nightly_quantization_tests,
|
||||
run_nightly_pipeline_level_quantization_tests,
|
||||
run_nightly_onnx_tests,
|
||||
# run_nightly_onnx_tests,
|
||||
torch_minimum_version_cuda_tests,
|
||||
run_flax_tpu_tests
|
||||
# run_flax_tpu_tests
|
||||
]
|
||||
if: always()
|
||||
runs-on:
|
||||
|
||||
@@ -87,11 +87,6 @@ jobs:
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_models_schedulers
|
||||
- name: Fast Flax CPU tests
|
||||
framework: flax
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: PyTorch Example CPU tests
|
||||
framework: pytorch_examples
|
||||
runner: aws-general-8-plus
|
||||
@@ -147,15 +142,6 @@ jobs:
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests
|
||||
|
||||
- name: Run example PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_examples' }}
|
||||
run: |
|
||||
|
||||
@@ -159,102 +159,6 @@ jobs:
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}
|
||||
path: reports
|
||||
|
||||
flax_tpu_tests:
|
||||
name: Flax TPU Tests
|
||||
runs-on:
|
||||
group: gcp-ct5lp-hightpu-8t
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
onnx_cuda_tests:
|
||||
name: ONNX CUDA Tests
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: onnx_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
run_torch_compile_tests:
|
||||
name: PyTorch Compile CUDA tests
|
||||
|
||||
|
||||
@@ -33,16 +33,6 @@ jobs:
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu
|
||||
- name: Fast Flax CPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
- name: PyTorch Example CPU tests on Ubuntu
|
||||
framework: pytorch_examples
|
||||
runner: aws-general-8-plus
|
||||
@@ -87,24 +77,6 @@ jobs:
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run fast ONNXRuntime CPU tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run example PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_examples' }}
|
||||
run: |
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
name: Fast mps tests on main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "src/diffusers/**.py"
|
||||
- "tests/**.py"
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
|
||||
@@ -213,101 +213,6 @@ jobs:
|
||||
with:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
flax_tpu_tests:
|
||||
name: Flax TPU Tests
|
||||
runs-on: docker-tpu
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
onnx_cuda_tests:
|
||||
name: ONNX CUDA Tests
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: onnx_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
run_torch_compile_tests:
|
||||
name: PyTorch Compile CUDA tests
|
||||
|
||||
@@ -180,6 +180,8 @@
|
||||
title: Caching
|
||||
- local: optimization/memory
|
||||
title: Reduce memory usage
|
||||
- local: optimization/speed-memory-optims
|
||||
title: Compile and offloading quantized models
|
||||
- local: optimization/pruna
|
||||
title: Pruna
|
||||
- local: optimization/xformers
|
||||
|
||||
@@ -37,6 +37,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
</Tip>
|
||||
|
||||
## LoraBaseMixin
|
||||
|
||||
[[autodoc]] loaders.lora_base.LoraBaseMixin
|
||||
|
||||
## StableDiffusionLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
|
||||
@@ -96,10 +100,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
|
||||
|
||||
## LoraBaseMixin
|
||||
|
||||
[[autodoc]] loaders.lora_base.LoraBaseMixin
|
||||
|
||||
## WanLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
|
||||
@@ -27,9 +27,36 @@ Chroma can use all the same optimizations as Flux.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Inference (Single File)
|
||||
## Inference
|
||||
|
||||
The `ChromaTransformer2DModel` supports loading checkpoints in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
|
||||
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ChromaPipeline
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
|
||||
pipe.enabe_model_cpu_offload()
|
||||
|
||||
prompt = [
|
||||
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
|
||||
]
|
||||
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=torch.Generator("cpu").manual_seed(433),
|
||||
num_inference_steps=40,
|
||||
guidance_scale=3.0,
|
||||
num_images_per_prompt=1,
|
||||
).images[0]
|
||||
image.save("chroma.png")
|
||||
```
|
||||
|
||||
## Loading from a single file
|
||||
|
||||
To use updated model checkpoints that are not in the Diffusers format, you can use the `ChromaTransformer2DModel` class to load the model from a single file in the original format. This is also useful when trying to load finetunes or quantized versions of the models that have been published by the community.
|
||||
|
||||
The following example demonstrates how to run Chroma from a single file.
|
||||
|
||||
@@ -38,30 +65,29 @@ Then run the following example
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ChromaTransformer2DModel, ChromaPipeline
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
bfl_repo = "black-forest-labs/FLUX.1-dev"
|
||||
model_id = "lodestones/Chroma"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v35.safetensors", torch_dtype=dtype)
|
||||
|
||||
text_encoder = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
|
||||
tokenizer = T5Tokenizer.from_pretrained(bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype)
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained(bfl_repo, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, torch_dtype=dtype)
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = "A cat holding a sign that says hello world"
|
||||
prompt = [
|
||||
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
|
||||
]
|
||||
negative_prompt = ["low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"]
|
||||
|
||||
image = pipe(
|
||||
prompt,
|
||||
guidance_scale=4.0,
|
||||
output_type="pil",
|
||||
num_inference_steps=26,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=torch.Generator("cpu").manual_seed(433),
|
||||
num_inference_steps=40,
|
||||
guidance_scale=3.0,
|
||||
).images[0]
|
||||
|
||||
image.save("image.png")
|
||||
image.save("chroma-single-file.png")
|
||||
```
|
||||
|
||||
## ChromaPipeline
|
||||
@@ -69,3 +95,9 @@ image.save("image.png")
|
||||
[[autodoc]] ChromaPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ChromaImg2ImgPipeline
|
||||
|
||||
[[autodoc]] ChromaImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -24,6 +24,31 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
|
||||
</Tip>
|
||||
|
||||
## Loading original format checkpoints
|
||||
|
||||
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
|
||||
|
||||
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
|
||||
transformer = CosmosTransformer3DModel.from_single_file(
|
||||
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
|
||||
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
|
||||
).images[0]
|
||||
output.save("output.png")
|
||||
```
|
||||
|
||||
## CosmosTextToWorldPipeline
|
||||
|
||||
[[autodoc]] CosmosTextToWorldPipeline
|
||||
|
||||
@@ -39,6 +39,7 @@ Flux comes in the following variants:
|
||||
| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
|
||||
| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
|
||||
| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
|
||||
| Kontext | [`black-forest-labs/FLUX.1-kontext`](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) |
|
||||
|
||||
All checkpoints have different usage which we detail below.
|
||||
|
||||
@@ -273,6 +274,46 @@ images = pipe(
|
||||
images[0].save("flux-redux.png")
|
||||
```
|
||||
|
||||
### Kontext
|
||||
|
||||
Flux Kontext is a model that allows in-context control of the image generation process, allowing for editing, refinement, relighting, style transfer, character customization, and more.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxKontextPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = FluxKontextPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png").convert("RGB")
|
||||
prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
|
||||
image = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
guidance_scale=2.5,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
image.save("flux-kontext.png")
|
||||
```
|
||||
|
||||
Flux Kontext comes with an integrity safety checker, which should be run after the image generation step. To run the safety checker, install the official repository from [black-forest-labs/flux](https://github.com/black-forest-labs/flux) and add the following code:
|
||||
|
||||
```python
|
||||
from flux.content_filters import PixtralContentFilter
|
||||
|
||||
# ... pipeline invocation to generate images
|
||||
|
||||
integrity_checker = PixtralContentFilter(torch.device("cuda"))
|
||||
image_ = np.array(image) / 255.0
|
||||
image_ = 2 * image_ - 1
|
||||
image_ = torch.from_numpy(image_).to("cuda", dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)
|
||||
if integrity_checker.test_image(image_):
|
||||
raise ValueError("Your image has been flagged. Choose another prompt/image or try again.")
|
||||
```
|
||||
|
||||
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
|
||||
|
||||
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
|
||||
|
||||
@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
```py
|
||||
# pip install ftfy
|
||||
import torch
|
||||
from diffusers import WanPipeline, AutoModel
|
||||
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
|
||||
|
||||
vae = AutoModel.from_single_file(
|
||||
vae = AutoencoderKLWan.from_single_file(
|
||||
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
|
||||
)
|
||||
transformer = AutoModel.from_single_file(
|
||||
transformer = WanTransformer3DModel.from_single_file(
|
||||
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
@@ -150,11 +150,63 @@ pipeline(prompt, num_inference_steps=30).images[0]
|
||||
|
||||
Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.
|
||||
|
||||
### Dynamic shape compilation
|
||||
|
||||
> [!TIP]
|
||||
> Make sure to always use the nightly version of PyTorch for better support.
|
||||
|
||||
`torch.compile` keeps track of input shapes and conditions, and if these are different, it recompiles the model. For example, if a model is compiled on a 1024x1024 resolution image and used on an image with a different resolution, it triggers recompilation.
|
||||
|
||||
To avoid recompilation, add `dynamic=True` to try and generate a more dynamic kernel to avoid recompilation when conditions change.
|
||||
|
||||
```diff
|
||||
+ torch.fx.experimental._config.use_duck_shape = False
|
||||
+ pipeline.unet = torch.compile(
|
||||
pipeline.unet, fullgraph=True, dynamic=True
|
||||
)
|
||||
```
|
||||
|
||||
Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
|
||||
|
||||
Not all models may benefit from dynamic compilation out of the box and may require changes. Refer to this [PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the [`AuraFlowPipeline`] implementation to benefit from dynamic compilation.
|
||||
|
||||
Feel free to open an issue if dynamic compilation doesn't work as expected for a Diffusers model.
|
||||
|
||||
### Regional compilation
|
||||
|
||||
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
|
||||
|
||||
[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
|
||||
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
|
||||
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10 ×**.
|
||||
|
||||
To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
|
||||
|
||||
```py
|
||||
# pip install -U diffusers
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
|
||||
# Compile only the repeated Transformer layers inside the UNet
|
||||
pipe.unet.compile_repeated_blocks(fullgraph=True)
|
||||
```
|
||||
|
||||
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
|
||||
|
||||
|
||||
```py
|
||||
class MyUNet(ModelMixin):
|
||||
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default
|
||||
```
|
||||
|
||||
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
|
||||
|
||||
**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
|
||||
|
||||
|
||||
|
||||
```py
|
||||
# pip install -U accelerate
|
||||
@@ -167,6 +219,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
).to("cuda")
|
||||
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
||||
```
|
||||
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
|
||||
|
||||
|
||||
### Graph breaks
|
||||
|
||||
@@ -241,4 +295,4 @@ An input is projected into three subspaces, represented by the projection matric
|
||||
|
||||
```py
|
||||
pipeline.fuse_qkv_projections()
|
||||
```
|
||||
```
|
||||
|
||||
@@ -17,7 +17,7 @@ Modern diffusion models like [Flux](../api/pipelines/flux) and [Wan](../api/pipe
|
||||
This guide will show you how to reduce your memory usage.
|
||||
|
||||
> [!TIP]
|
||||
> Keep in mind these techniques may need to be adjusted depending on the model! For example, a transformer-based diffusion model may not benefit equally from these inference speed optimizations as a UNet-based model.
|
||||
> Keep in mind these techniques may need to be adjusted depending on the model. For example, a transformer-based diffusion model may not benefit equally from these memory optimizations as a UNet-based model.
|
||||
|
||||
## Multiple GPUs
|
||||
|
||||
@@ -63,7 +63,12 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
> [!WARNING]
|
||||
> Device placement is an experimental feature and the API may change. Only the `balanced` strategy is supported at the moment. We plan to support additional mapping strategies in the future.
|
||||
|
||||
The `device_map` parameter controls how the model components in a pipeline are distributed across devices. The `balanced` device placement strategy evenly splits the pipeline across all available devices.
|
||||
The `device_map` parameter controls how the model components in a pipeline or the layers in an individual model are distributed across devices.
|
||||
|
||||
<hfoptions id="device-map">
|
||||
<hfoption id="pipeline level">
|
||||
|
||||
The `balanced` device placement strategy evenly splits the pipeline across all available devices.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -83,7 +88,10 @@ print(pipeline.hf_device_map)
|
||||
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
|
||||
```
|
||||
|
||||
The `device_map` parameter also works on the model-level. This is useful for loading large models, such as the Flux diffusion transformer which has 12.5B parameters. Instead of `balanced`, set it to `"auto"` to automatically distribute a model across the fastest device first before moving to slower devices. Refer to the [Model sharding](../training/distributed_inference#model-sharding) docs for more details.
|
||||
</hfoption>
|
||||
<hfoption id="model level">
|
||||
|
||||
The `device_map` is useful for loading large models, such as the Flux diffusion transformer which has 12.5B parameters. Set it to `"auto"` to automatically distribute a model across the fastest device first before moving to slower devices. Refer to the [Model sharding](../training/distributed_inference#model-sharding) docs for more details.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -97,7 +105,43 @@ transformer = AutoModel.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
For more fine-grained control, pass a dictionary to enforce the maximum GPU memory to use on each device. If a device is not in `max_memory`, it is ignored and pipeline components won't be distributed to it.
|
||||
You can inspect a model's device map with `hf_device_map`.
|
||||
|
||||
```py
|
||||
print(transformer.hf_device_map)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
When designing your own `device_map`, it should be a dictionary of a model's specific module name or layer and a device identifier (an integer for GPUs, `cpu` for CPUs, and `disk` for disk).
|
||||
|
||||
Call `hf_device_map` on a model to see how model layers are distributed and then design your own.
|
||||
|
||||
```py
|
||||
print(transformer.hf_device_map)
|
||||
{'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 'cpu', 'single_transformer_blocks.11': 'cpu', 'single_transformer_blocks.12': 'cpu', 'single_transformer_blocks.13': 'cpu', 'single_transformer_blocks.14': 'cpu', 'single_transformer_blocks.15': 'cpu', 'single_transformer_blocks.16': 'cpu', 'single_transformer_blocks.17': 'cpu', 'single_transformer_blocks.18': 'cpu', 'single_transformer_blocks.19': 'cpu', 'single_transformer_blocks.20': 'cpu', 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'}
|
||||
```
|
||||
|
||||
For example, the `device_map` below places `single_transformer_blocks.10` through `single_transformer_blocks.20` on a second GPU (`1`).
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel
|
||||
|
||||
device_map = {
|
||||
'pos_embed': 0, 'time_text_embed': 0, 'context_embedder': 0, 'x_embedder': 0, 'transformer_blocks': 0, 'single_transformer_blocks.0': 0, 'single_transformer_blocks.1': 0, 'single_transformer_blocks.2': 0, 'single_transformer_blocks.3': 0, 'single_transformer_blocks.4': 0, 'single_transformer_blocks.5': 0, 'single_transformer_blocks.6': 0, 'single_transformer_blocks.7': 0, 'single_transformer_blocks.8': 0, 'single_transformer_blocks.9': 0, 'single_transformer_blocks.10': 1, 'single_transformer_blocks.11': 1, 'single_transformer_blocks.12': 1, 'single_transformer_blocks.13': 1, 'single_transformer_blocks.14': 1, 'single_transformer_blocks.15': 1, 'single_transformer_blocks.16': 1, 'single_transformer_blocks.17': 1, 'single_transformer_blocks.18': 1, 'single_transformer_blocks.19': 1, 'single_transformer_blocks.20': 1, 'single_transformer_blocks.21': 'cpu', 'single_transformer_blocks.22': 'cpu', 'single_transformer_blocks.23': 'cpu', 'single_transformer_blocks.24': 'cpu', 'single_transformer_blocks.25': 'cpu', 'single_transformer_blocks.26': 'cpu', 'single_transformer_blocks.27': 'cpu', 'single_transformer_blocks.28': 'cpu', 'single_transformer_blocks.29': 'cpu', 'single_transformer_blocks.30': 'cpu', 'single_transformer_blocks.31': 'cpu', 'single_transformer_blocks.32': 'cpu', 'single_transformer_blocks.33': 'cpu', 'single_transformer_blocks.34': 'cpu', 'single_transformer_blocks.35': 'cpu', 'single_transformer_blocks.36': 'cpu', 'single_transformer_blocks.37': 'cpu', 'norm_out': 'cpu', 'proj_out': 'cpu'
|
||||
}
|
||||
|
||||
transformer = AutoModel.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
subfolder="transformer",
|
||||
device_map=device_map,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
```
|
||||
|
||||
Pass a dictionary mapping maximum memory usage to each device to enforce a limit. If a device is not in `max_memory`, it is ignored and pipeline components won't be distributed to it.
|
||||
|
||||
```py
|
||||
import torch
|
||||
@@ -145,7 +189,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support slicing.
|
||||
> The [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] classes don't support slicing.
|
||||
|
||||
## VAE tiling
|
||||
|
||||
@@ -172,7 +216,13 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
|
||||
> [!WARNING]
|
||||
> [`AutoencoderKLWan`] and [`AsymmetricAutoencoderKL`] don't support tiling.
|
||||
|
||||
## CPU offloading
|
||||
## Offloading
|
||||
|
||||
Offloading strategies move not currently active layers or models to the CPU to avoid increasing GPU memory. These strategies can be combined with quantization and torch.compile to balance inference speed and memory usage.
|
||||
|
||||
Refer to the [Compile and offloading quantized models](./speed-memory-optims) guide for more details.
|
||||
|
||||
### CPU offloading
|
||||
|
||||
CPU offloading selectively moves weights from the GPU to the CPU. When a component is required, it is transferred to the GPU and when it isn't required, it is moved to the CPU. This method works on submodules rather than whole models. It saves memory by avoiding storing the entire model on the GPU.
|
||||
|
||||
@@ -203,7 +253,7 @@ pipeline(
|
||||
print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
|
||||
```
|
||||
|
||||
## Model offloading
|
||||
### Model offloading
|
||||
|
||||
Model offloading moves entire models to the GPU instead of selectively moving *some* layers or model components. One of the main pipeline models, usually the text encoder, UNet, and VAE, is placed on the GPU while the other components are held on the CPU. Components like the UNet that run multiple times stays on the GPU until its completely finished and no longer needed. This eliminates the communication overhead of [CPU offloading](#cpu-offloading) and makes model offloading a faster alternative. The tradeoff is memory savings won't be as large.
|
||||
|
||||
@@ -219,7 +269,7 @@ from diffusers import DiffusionPipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipline.enable_model_cpu_offload()
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
pipeline(
|
||||
prompt="An astronaut riding a horse on Mars",
|
||||
@@ -234,7 +284,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
|
||||
|
||||
[`~DiffusionPipeline.enable_model_cpu_offload`] also helps when you're using the [`~StableDiffusionXLPipeline.encode_prompt`] method on its own to generate the text encoders hidden state.
|
||||
|
||||
## Group offloading
|
||||
### Group offloading
|
||||
|
||||
Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html) or [torch.nn.Sequential](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html)) to the CPU. It uses less memory than [model offloading](#model-offloading) and it is faster than [CPU offloading](#cpu-offloading) because it reduces communication overhead.
|
||||
|
||||
@@ -278,7 +328,7 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
### CUDA stream
|
||||
#### CUDA stream
|
||||
|
||||
The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.
|
||||
|
||||
@@ -295,22 +345,25 @@ pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_d
|
||||
|
||||
The `low_cpu_mem_usage` parameter can be set to `True` to reduce CPU memory usage when using streams during group offloading. It is best for `leaf_level` offloading and when CPU memory is bottlenecked. Memory is saved by creating pinned tensors on the fly instead of pre-pinning them. However, this may increase overall execution time.
|
||||
|
||||
<Tip>
|
||||
#### Offloading to disk
|
||||
|
||||
The offloading strategies can be combined with [quantization](../quantization/overview.md) to enable further memory savings. For image generation, combining [quantization and model offloading](#model-offloading) can often give the best trade-off between quality, speed, and memory. However, for video generation, as the models are more
|
||||
compute-bound, [group-offloading](#group-offloading) tends to be better. Group offloading provides considerable benefits when weight transfers can be overlapped with computation (must use streams). When applying group offloading with quantization on image generation models at typical resolutions (1024x1024, for example), it is usually not possible to *fully* overlap weight transfers if the compute kernel finishes faster, making it communication bound between CPU/GPU (due to device synchronizations).
|
||||
Group offloading can consume significant system memory depending on the model size. On systems with limited memory, try group offloading onto the disk as a secondary memory.
|
||||
|
||||
</Tip>
|
||||
Set the `offload_to_disk_path` argument in either [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`] to offload the model to the disk.
|
||||
|
||||
### Offloading to disk
|
||||
```py
|
||||
pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", offload_to_disk_path="path/to/disk")
|
||||
|
||||
Group offloading can consume significant system RAM depending on the model size. In limited RAM environments,
|
||||
it can be useful to offload to the second memory, instead. You can do this by setting the `offload_to_disk_path`
|
||||
argument in either of [`~ModelMixin.enable_group_offload`] or [`~hooks.apply_group_offloading`]. Refer [here](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) and
|
||||
[here](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) for the expected speed-memory trade-offs with this option enabled.
|
||||
apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2, offload_to_disk_path="path/to/disk")
|
||||
```
|
||||
|
||||
Refer to these [two](https://github.com/huggingface/diffusers/pull/11682#issue-3129365363) [tables](https://github.com/huggingface/diffusers/pull/11682#issuecomment-2955715126) to compare the speed and memory trade-offs.
|
||||
|
||||
## Layerwise casting
|
||||
|
||||
> [!TIP]
|
||||
> Combine layerwise casting with [group offloading](#group-offloading) for even more memory savings.
|
||||
|
||||
Layerwise casting stores weights in a smaller data format (for example, `torch.float8_e4m3fn` and `torch.float8_e5m2`) to use less memory and upcasts those weights to a higher precision like `torch.float16` or `torch.bfloat16` for computation. Certain layers (normalization and modulation related weights) are skipped because storing them in fp8 can degrade generation quality.
|
||||
|
||||
> [!WARNING]
|
||||
@@ -500,7 +553,7 @@ with torch.inference_mode():
|
||||
## Memory-efficient attention
|
||||
|
||||
> [!TIP]
|
||||
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention!
|
||||
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!
|
||||
|
||||
The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.
|
||||
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Compile and offloading quantized models
|
||||
|
||||
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
|
||||
|
||||
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
|
||||
|
||||
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
|
||||
|
||||
The table below provides a comparison of optimization strategy combinations and their impact on latency and memory-usage for Flux.
|
||||
|
||||
| combination | latency (s) | memory-usage (GB) |
|
||||
|---|---|---|
|
||||
| quantization | 32.602 | 14.9453 |
|
||||
| quantization, torch.compile | 25.847 | 14.9448 |
|
||||
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
|
||||
<small>These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the <a href="https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d" benchmarking script</a> if you're interested in evaluating your own model.</small>
|
||||
|
||||
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
|
||||
|
||||
```bash
|
||||
pip install -U bitsandbytes
|
||||
```
|
||||
|
||||
## Quantization and torch.compile
|
||||
|
||||
Start by [quantizing](../quantization/overview) a model to reduce the memory required for storage and [compiling](./fp16#torchcompile) it to accelerate inference.
|
||||
|
||||
Configure the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
|
||||
# quantize
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_4bit",
|
||||
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
# compile
|
||||
pipeline.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.transformer.compile(mode="max-autotune", fullgraph=True)
|
||||
pipeline("""
|
||||
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
|
||||
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
|
||||
"""
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Quantization, torch.compile, and offloading
|
||||
|
||||
In addition to quantization and torch.compile, try offloading if you need to reduce memory-usage further. Offloading moves various layers or model components from the CPU to the GPU as needed for computations.
|
||||
|
||||
Configure the [Dynamo](https://docs.pytorch.org/docs/stable/torch.compiler_dynamo_overview.html) `cache_size_limit` during offloading to avoid excessive recompilation and set `capture_dynamic_output_shape_ops = True` to handle dynamic outputs when compiling bitsandbytes models.
|
||||
|
||||
<hfoptions id="offloading">
|
||||
<hfoption id="model CPU offloading">
|
||||
|
||||
[Model CPU offloading](./memory#model-offloading) moves an individual pipeline component, like the transformer model, to the GPU when it is needed for computation. Otherwise, it is offloaded to the CPU.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 1000
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
|
||||
# quantize
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_4bit",
|
||||
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
|
||||
components_to_quantize=["transformer", "text_encoder_2"],
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
# model CPU offloading
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
# compile
|
||||
pipeline.transformer.compile()
|
||||
pipeline(
|
||||
"cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
|
||||
).images[0]
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="group offloading">
|
||||
|
||||
[Group offloading](./memory#group-offloading) moves the internal layers of an individual pipeline component, like the transformer model, to the GPU for computation and offloads it when it's not required. At the same time, it uses the [CUDA stream](./memory#cuda-stream) feature to prefetch the next layer for execution.
|
||||
|
||||
By overlapping computation and data transfer, it is faster than model CPU offloading while also saving memory.
|
||||
|
||||
```py
|
||||
# pip install ftfy
|
||||
import torch
|
||||
from diffusers import AutoModel, DiffusionPipeline
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers.quantizers import PipelineQuantizationConfig
|
||||
from transformers import UMT5EncoderModel
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 1000
|
||||
torch._dynamo.config.capture_dynamic_output_shape_ops = True
|
||||
|
||||
# quantize
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_backend="bitsandbytes_4bit",
|
||||
quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
|
||||
components_to_quantize=["transformer", "text_encoder"],
|
||||
)
|
||||
|
||||
text_encoder = UMT5EncoderModel.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="text_encoder", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-14B-Diffusers",
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
|
||||
# group offloading
|
||||
onload_device = torch.device("cuda")
|
||||
offload_device = torch.device("cpu")
|
||||
|
||||
pipeline.transformer.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True,
|
||||
non_blocking=True
|
||||
)
|
||||
pipeline.vae.enable_group_offload(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True,
|
||||
non_blocking=True
|
||||
)
|
||||
apply_group_offloading(
|
||||
pipeline.text_encoder,
|
||||
onload_device=onload_device,
|
||||
offload_type="leaf_level",
|
||||
use_stream=True,
|
||||
non_blocking=True
|
||||
)
|
||||
|
||||
# compile
|
||||
pipeline.transformer.compile()
|
||||
|
||||
prompt = """
|
||||
The camera rushes from far to near in a low-angle shot,
|
||||
revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
|
||||
for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
|
||||
Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
|
||||
shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=81,
|
||||
guidance_scale=5.0,
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
@@ -203,6 +203,46 @@ pipeline("bears, pizza bites").images[0]
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Scale scheduling
|
||||
|
||||
Dynamically adjusting the LoRA scale during sampling gives you better control over the overall composition and layout because certain steps may benefit more from an increased or reduced scale.
|
||||
|
||||
The [character LoRA](https://huggingface.co/alvarobartt/ghibli-characters-flux-lora) in the example below starts with a higher scale that gradually decays over the first 20 steps to establish the character generation. In the later steps, only a scale of 0.2 is applied to avoid adding too much of the LoRA features to other parts of the image the LoRA wasn't trained on.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
pipelne.load_lora_weights("alvarobartt/ghibli-characters-flux-lora", "lora")
|
||||
|
||||
num_inference_steps = 30
|
||||
lora_steps = 20
|
||||
lora_scales = torch.linspace(1.5, 0.7, lora_steps).tolist()
|
||||
lora_scales += [0.2] * (num_inference_steps - lora_steps + 1)
|
||||
|
||||
pipeline.set_adapters("lora", lora_scales[0])
|
||||
|
||||
def callback(pipeline: FluxPipeline, step: int, timestep: torch.LongTensor, callback_kwargs: dict):
|
||||
pipeline.set_adapters("lora", lora_scales[step + 1])
|
||||
return callback_kwargs
|
||||
|
||||
prompt = """
|
||||
Ghibli style The Grinch, a mischievous green creature with a sly grin, peeking out from behind a snow-covered tree while plotting his antics,
|
||||
in a quaint snowy village decorated for the holidays, warm light glowing from cozy homes, with playful snowflakes dancing in the air
|
||||
"""
|
||||
pipeline(
|
||||
prompt=prompt,
|
||||
guidance_scale=3.0,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
callback_on_step_end=callback,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Hotswapping
|
||||
|
||||
Hotswapping LoRAs is an efficient way to work with multiple LoRAs while avoiding accumulating memory from multiple calls to [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] and in some cases, recompilation, if a model is compiled. This workflow requires a loaded LoRA because the new LoRA weights are swapped in place for the existing loaded LoRA.
|
||||
@@ -275,6 +315,8 @@ pipeline.load_lora_weights(
|
||||
> [!TIP]
|
||||
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
|
||||
|
||||
If you expect to varied resolutions during inference with this feature, then make sure set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
|
||||
|
||||
There are still scenarios where recompulation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
|
||||
|
||||
## Merge
|
||||
|
||||
@@ -75,7 +75,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
|
||||
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -260,5 +260,51 @@ to enable `latent_caching` simply pass `--cache_latents`.
|
||||
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
|
||||
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
|
||||
|
||||
## Training Kontext
|
||||
|
||||
[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
|
||||
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too.
|
||||
|
||||
Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
|
||||
|
||||
Below is an example training command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth_lora_flux_kontext.py \
|
||||
--pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
|
||||
--instance_data_dir="dog" \
|
||||
--output_dir="kontext-dog" \
|
||||
--mixed_precision="bf16" \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--optimizer="adamw" \
|
||||
--use_8bit_adam \
|
||||
--cache_latents \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
|
||||
perform as expected.
|
||||
|
||||
### Misc notes
|
||||
|
||||
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
|
||||
### Aspect Ratio Bucketing
|
||||
we've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
|
||||
|
||||
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
|
||||
|
||||
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
|
||||
`
|
||||
Since Flux Kontext finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
|
||||
|
||||
## Other notes
|
||||
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
|
||||
@@ -0,0 +1,281 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRAFluxKontext(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_flux_kontext.py"
|
||||
transformer_layer_type = "single_transformer_blocks.0.attn.to_k"
|
||||
|
||||
def test_dreambooth_lora_flux_kontext(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_text_encoder_flux_kontext(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
starts_with_expected_prefix = all(
|
||||
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_expected_prefix)
|
||||
|
||||
def test_dreambooth_lora_latent_caching(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_layers(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--cache_latents
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lora_layers {self.transformer_layer_type}
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names. In this test, we only params of
|
||||
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
|
||||
starts_with_transformer = all(
|
||||
key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_with_metadata(self):
|
||||
# Use a `lora_alpha` that is different from `rank`.
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -20,6 +21,8 @@ import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
@@ -204,3 +207,42 @@ class DreamBoothLoRASANA(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
|
||||
def test_dreambooth_lora_sana_with_metadata(self):
|
||||
lora_alpha = 8
|
||||
rank = 4
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--resolution=32
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--lora_alpha={lora_alpha}
|
||||
--rank={rank}
|
||||
--checkpointing_steps=2
|
||||
--max_sequence_length 166
|
||||
""".split()
|
||||
|
||||
test_args.extend(["--instance_prompt", ""])
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
|
||||
self.assertTrue(os.path.isfile(state_dict_file))
|
||||
|
||||
# Check if the metadata was properly serialized.
|
||||
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata() or {}
|
||||
|
||||
metadata.pop("format", None)
|
||||
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
|
||||
if raw:
|
||||
raw = json.loads(raw)
|
||||
|
||||
loaded_lora_alpha = raw["transformer.lora_alpha"]
|
||||
self.assertTrue(loaded_lora_alpha == lora_alpha)
|
||||
loaded_lora_rank = raw["transformer.r"]
|
||||
self.assertTrue(loaded_lora_rank == rank)
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
# Cache compiled models across invocations of this script.
|
||||
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
|
||||
|
||||
@@ -65,7 +65,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -73,7 +73,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.33.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
@@ -71,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -323,9 +324,13 @@ def parse_args(input_args=None):
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=int,
|
||||
default=4,
|
||||
help="LoRA alpha to be used for additional scaling.",
|
||||
)
|
||||
parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
|
||||
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
@@ -1023,7 +1028,7 @@ def main(args):
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=target_modules,
|
||||
@@ -1039,10 +1044,11 @@ def main(args):
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1052,6 +1058,7 @@ def main(args):
|
||||
SanaPipeline.save_lora_weights(
|
||||
output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers_to_save,
|
||||
**_collate_lora_metadata(modules_to_save),
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
@@ -1507,15 +1514,18 @@ def main(args):
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
transformer = unwrap_model(transformer)
|
||||
modules_to_save = {}
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
SanaPipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
**_collate_lora_metadata(modules_to_save),
|
||||
)
|
||||
|
||||
# Final inference
|
||||
|
||||
@@ -72,7 +72,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ if is_wandb_available():
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
if is_torch_npu_available():
|
||||
|
||||
@@ -81,7 +81,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ else:
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ else:
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.34.0.dev0")
|
||||
check_min_version("0.35.0.dev0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
|
||||
@@ -95,7 +95,6 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
||||
"mlp.layer1": "ff.net.0.proj",
|
||||
"mlp.layer2": "ff.net.2",
|
||||
"x_embedder.proj.1": "patch_embed.proj",
|
||||
# "extra_pos_embedder": "learnable_pos_embed",
|
||||
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
|
||||
@@ -269,7 +269,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.34.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.35.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
__version__ = "0.34.0.dev0"
|
||||
__version__ = "0.35.0.dev0"
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -381,6 +381,7 @@ else:
|
||||
"FluxFillPipeline",
|
||||
"FluxImg2ImgPipeline",
|
||||
"FluxInpaintPipeline",
|
||||
"FluxKontextPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"HiDreamImagePipeline",
|
||||
@@ -974,6 +975,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxFillPipeline,
|
||||
FluxImg2ImgPipeline,
|
||||
FluxInpaintPipeline,
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
HiDreamImagePipeline,
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import safetensors.torch
|
||||
@@ -46,6 +48,24 @@ _SUPPORTED_PYTORCH_LAYERS = (
|
||||
# fmt: on
|
||||
|
||||
|
||||
class GroupOffloadingType(str, Enum):
|
||||
BLOCK_LEVEL = "block_level"
|
||||
LEAF_LEVEL = "leaf_level"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupOffloadingConfig:
|
||||
onload_device: torch.device
|
||||
offload_device: torch.device
|
||||
offload_type: GroupOffloadingType
|
||||
non_blocking: bool
|
||||
record_stream: bool
|
||||
low_cpu_mem_usage: bool
|
||||
num_blocks_per_group: Optional[int] = None
|
||||
offload_to_disk_path: Optional[str] = None
|
||||
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
|
||||
|
||||
|
||||
class ModuleGroup:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -96,9 +116,6 @@ class ModuleGroup:
|
||||
else:
|
||||
self.cpu_param_dict = self._init_cpu_param_dict()
|
||||
|
||||
if self.stream is None and self.record_stream:
|
||||
raise ValueError("`record_stream` cannot be True when `stream` is None.")
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
@@ -135,9 +152,58 @@ class ModuleGroup:
|
||||
finally:
|
||||
pinned_dict = None
|
||||
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
|
||||
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream and current_stream is not None:
|
||||
tensor.data.record_stream(current_stream)
|
||||
|
||||
def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
source = pinned_memory[param] if pinned_memory else param.data
|
||||
self._transfer_tensor_to_device(param, source, current_stream)
|
||||
for buffer in group_module.buffers():
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, current_stream)
|
||||
|
||||
for param in self.parameters:
|
||||
source = pinned_memory[param] if pinned_memory else param.data
|
||||
self._transfer_tensor_to_device(param, source, current_stream)
|
||||
|
||||
for buffer in self.buffers:
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, current_stream)
|
||||
|
||||
def _onload_from_disk(self, current_stream):
|
||||
if self.stream is not None:
|
||||
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
|
||||
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
|
||||
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
|
||||
|
||||
self.cpu_param_dict.clear()
|
||||
|
||||
else:
|
||||
onload_device = (
|
||||
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
|
||||
)
|
||||
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
|
||||
for key, tensor_obj in self.key_to_tensor.items():
|
||||
tensor_obj.data = loaded_tensors[key]
|
||||
|
||||
def _onload_from_memory(self, current_stream):
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
self._process_tensors_from_modules(pinned_memory, current_stream)
|
||||
else:
|
||||
self._process_tensors_from_modules(None, current_stream)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
r"""Onloads the group of modules to the onload_device."""
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
@@ -175,67 +241,30 @@ class ModuleGroup:
|
||||
self.stream.synchronize()
|
||||
|
||||
with context:
|
||||
if self.stream is not None:
|
||||
with self._pinned_memory_tensors() as pinned_memory:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
param.data.record_stream(current_stream)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
buffer.data.record_stream(current_stream)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
param.data.record_stream(current_stream)
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
buffer.data.record_stream(current_stream)
|
||||
|
||||
if self.offload_to_disk_path:
|
||||
self._onload_from_disk(current_stream)
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
for buffer in group_module.buffers():
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
self._onload_from_memory(current_stream)
|
||||
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
def _offload_to_disk(self):
|
||||
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
|
||||
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
|
||||
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
||||
# we perform a write.
|
||||
# Check if the file has been saved in this session or if it already exists on disk.
|
||||
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
|
||||
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
|
||||
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
|
||||
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
|
||||
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if self.record_stream:
|
||||
buffer.data.record_stream(current_stream)
|
||||
# The group is now considered offloaded to disk for the rest of the session.
|
||||
self._is_offloaded_to_disk = True
|
||||
|
||||
@torch.compiler.disable()
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
if self.offload_to_disk_path:
|
||||
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
|
||||
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
|
||||
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
||||
# we perform a write.
|
||||
# Check if the file has been saved in this session or if it already exists on disk.
|
||||
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
|
||||
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
|
||||
tensors_to_save = {
|
||||
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
|
||||
}
|
||||
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
|
||||
|
||||
# The group is now considered offloaded to disk for the rest of the session.
|
||||
self._is_offloaded_to_disk = True
|
||||
|
||||
# We do this to free up the RAM which is still holding the up tensor data.
|
||||
for tensor_obj in self.tensor_to_key.keys():
|
||||
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
|
||||
return
|
||||
# We do this to free up the RAM which is still holding the up tensor data.
|
||||
for tensor_obj in self.tensor_to_key.keys():
|
||||
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
|
||||
|
||||
def _offload_to_memory(self):
|
||||
torch_accelerator_module = (
|
||||
getattr(torch, torch.accelerator.current_accelerator().type)
|
||||
if hasattr(torch, "accelerator")
|
||||
@@ -260,6 +289,14 @@ class ModuleGroup:
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def offload_(self):
|
||||
r"""Offloads the group of modules to the offload_device."""
|
||||
if self.offload_to_disk_path:
|
||||
self._offload_to_disk()
|
||||
else:
|
||||
self._offload_to_memory()
|
||||
|
||||
|
||||
class GroupOffloadingHook(ModelHook):
|
||||
r"""
|
||||
@@ -271,9 +308,12 @@ class GroupOffloadingHook(ModelHook):
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
|
||||
def __init__(
|
||||
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
|
||||
) -> None:
|
||||
self.group = group
|
||||
self.next_group = next_group
|
||||
self.config = config
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self.group.offload_leader == module:
|
||||
@@ -419,7 +459,7 @@ def apply_group_offloading(
|
||||
module: torch.nn.Module,
|
||||
onload_device: torch.device,
|
||||
offload_device: torch.device = torch.device("cpu"),
|
||||
offload_type: str = "block_level",
|
||||
offload_type: Union[str, GroupOffloadingType] = "block_level",
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
@@ -461,7 +501,7 @@ def apply_group_offloading(
|
||||
The device to which the group of modules are onloaded.
|
||||
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
|
||||
offload_type (`str`, defaults to "block_level"):
|
||||
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
|
||||
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
|
||||
"block_level".
|
||||
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
||||
@@ -504,6 +544,8 @@ def apply_group_offloading(
|
||||
```
|
||||
"""
|
||||
|
||||
offload_type = GroupOffloadingType(offload_type)
|
||||
|
||||
stream = None
|
||||
if use_stream:
|
||||
if torch.cuda.is_available():
|
||||
@@ -513,83 +555,47 @@ def apply_group_offloading(
|
||||
else:
|
||||
raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
|
||||
|
||||
if not use_stream and record_stream:
|
||||
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
|
||||
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
|
||||
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
if offload_type == "block_level":
|
||||
if num_blocks_per_group is None:
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
config = GroupOffloadingConfig(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
)
|
||||
_apply_group_offloading(module, config)
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module=module,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(
|
||||
module=module,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
||||
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
|
||||
_apply_group_offloading_block_level(module, config)
|
||||
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
|
||||
_apply_group_offloading_leaf_level(module, config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported offload_type: {offload_type}")
|
||||
assert False
|
||||
|
||||
|
||||
def _apply_group_offloading_block_level(
|
||||
module: torch.nn.Module,
|
||||
num_blocks_per_group: int,
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
) -> None:
|
||||
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
||||
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to which group offloading is applied.
|
||||
offload_device (`torch.device`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU.
|
||||
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
||||
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
|
||||
RAM environment settings where a reasonable speed-memory trade-off is desired.
|
||||
onload_device (`torch.device`):
|
||||
The device to which the group of modules are onloaded.
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
||||
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
||||
details.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
"""
|
||||
if stream is not None and num_blocks_per_group != 1:
|
||||
|
||||
if config.stream is not None and config.num_blocks_per_group != 1:
|
||||
logger.warning(
|
||||
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
|
||||
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
|
||||
)
|
||||
num_blocks_per_group = 1
|
||||
config.num_blocks_per_group = 1
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -601,19 +607,19 @@ def _apply_group_offloading_block_level(
|
||||
modules_with_group_offloading.add(name)
|
||||
continue
|
||||
|
||||
for i in range(0, len(submodule), num_blocks_per_group):
|
||||
current_modules = submodule[i : i + num_blocks_per_group]
|
||||
for i in range(0, len(submodule), config.num_blocks_per_group):
|
||||
current_modules = submodule[i : i + config.num_blocks_per_group]
|
||||
group = ModuleGroup(
|
||||
modules=current_modules,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
offload_leader=current_modules[-1],
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
non_blocking=config.non_blocking,
|
||||
stream=config.stream,
|
||||
record_stream=config.record_stream,
|
||||
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -623,7 +629,7 @@ def _apply_group_offloading_block_level(
|
||||
# Apply group offloading hooks to the module groups
|
||||
for i, group in enumerate(matched_module_groups):
|
||||
for group_module in group.modules:
|
||||
_apply_group_offloading_hook(group_module, group, None)
|
||||
_apply_group_offloading_hook(group_module, group, None, config=config)
|
||||
|
||||
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
|
||||
# when the forward pass of this module is called. This is because the top-level module is not
|
||||
@@ -638,9 +644,9 @@ def _apply_group_offloading_block_level(
|
||||
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
|
||||
unmatched_group = ModuleGroup(
|
||||
modules=unmatched_modules,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
offload_leader=module,
|
||||
onload_leader=module,
|
||||
parameters=parameters,
|
||||
@@ -650,54 +656,19 @@ def _apply_group_offloading_block_level(
|
||||
record_stream=False,
|
||||
onload_self=True,
|
||||
)
|
||||
if stream is None:
|
||||
_apply_group_offloading_hook(module, unmatched_group, None)
|
||||
if config.stream is None:
|
||||
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
else:
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
|
||||
|
||||
def _apply_group_offloading_leaf_level(
|
||||
module: torch.nn.Module,
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
) -> None:
|
||||
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
||||
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
|
||||
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
|
||||
reduce memory usage without any performance degradation.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to which group offloading is applied.
|
||||
offload_device (`torch.device`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU.
|
||||
onload_device (`torch.device`):
|
||||
The device to which the group of modules are onloaded.
|
||||
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
||||
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
|
||||
RAM environment settings where a reasonable speed-memory trade-off is desired.
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
||||
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
||||
details.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
"""
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
for name, submodule in module.named_modules():
|
||||
@@ -705,18 +676,18 @@ def _apply_group_offloading_leaf_level(
|
||||
continue
|
||||
group = ModuleGroup(
|
||||
modules=[submodule],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
offload_leader=submodule,
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
non_blocking=config.non_blocking,
|
||||
stream=config.stream,
|
||||
record_stream=config.record_stream,
|
||||
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
_apply_group_offloading_hook(submodule, group, None, config=config)
|
||||
modules_with_group_offloading.add(name)
|
||||
|
||||
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
|
||||
@@ -747,33 +718,32 @@ def _apply_group_offloading_leaf_level(
|
||||
parameters = parent_to_parameters.get(name, [])
|
||||
buffers = parent_to_buffers.get(name, [])
|
||||
parent_module = module_dict[name]
|
||||
assert getattr(parent_module, "_diffusers_hook", None) is None
|
||||
group = ModuleGroup(
|
||||
modules=[],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_leader=parent_module,
|
||||
onload_leader=parent_module,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
parameters=parameters,
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
non_blocking=config.non_blocking,
|
||||
stream=config.stream,
|
||||
record_stream=config.record_stream,
|
||||
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
_apply_group_offloading_hook(parent_module, group, None, config=config)
|
||||
|
||||
if stream is not None:
|
||||
if config.stream is not None:
|
||||
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
|
||||
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
|
||||
# execution order and apply prefetching in the correct order.
|
||||
unmatched_group = ModuleGroup(
|
||||
modules=[],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
offload_leader=module,
|
||||
onload_leader=module,
|
||||
parameters=None,
|
||||
@@ -781,23 +751,25 @@ def _apply_group_offloading_leaf_level(
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
record_stream=False,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
|
||||
|
||||
def _apply_group_offloading_hook(
|
||||
module: torch.nn.Module,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
*,
|
||||
config: GroupOffloadingConfig,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
||||
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
||||
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
||||
hook = GroupOffloadingHook(group, next_group)
|
||||
hook = GroupOffloadingHook(group, next_group, config=config)
|
||||
registry.register_hook(hook, _GROUP_OFFLOADING)
|
||||
|
||||
|
||||
@@ -805,13 +777,15 @@ def _apply_lazy_group_offloading_hook(
|
||||
module: torch.nn.Module,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
*,
|
||||
config: GroupOffloadingConfig,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
||||
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
||||
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
||||
hook = GroupOffloadingHook(group, next_group)
|
||||
hook = GroupOffloadingHook(group, next_group, config=config)
|
||||
registry.register_hook(hook, _GROUP_OFFLOADING)
|
||||
|
||||
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
|
||||
@@ -878,15 +852,48 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
|
||||
)
|
||||
|
||||
|
||||
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
|
||||
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
|
||||
return True
|
||||
return False
|
||||
if hasattr(submodule, "_diffusers_hook"):
|
||||
group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
|
||||
if group_offloading_hook is not None:
|
||||
return group_offloading_hook
|
||||
return None
|
||||
|
||||
|
||||
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
|
||||
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
||||
return top_level_group_offload_hook is not None
|
||||
|
||||
|
||||
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
|
||||
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
|
||||
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
||||
if top_level_group_offload_hook is not None:
|
||||
return top_level_group_offload_hook.config.onload_device
|
||||
raise ValueError("Group offloading is not enabled for the provided module.")
|
||||
|
||||
|
||||
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
|
||||
r"""
|
||||
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
|
||||
modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
|
||||
modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
|
||||
|
||||
In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
|
||||
and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
|
||||
case where user has applied group offloading at multiple levels, this function will not work as expected.
|
||||
|
||||
There is some performance penalty associated with doing this when non-default streams are used, because we need to
|
||||
retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
|
||||
"""
|
||||
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
||||
|
||||
if top_level_group_offload_hook is None:
|
||||
return
|
||||
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
|
||||
registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
|
||||
registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
|
||||
|
||||
_apply_group_offloading(module, top_level_group_offload_hook.config)
|
||||
|
||||
@@ -25,6 +25,7 @@ import torch.nn as nn
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
|
||||
from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
|
||||
from ..models.modeling_utils import ModelMixin, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
# <Unsafe code
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
|
||||
_pipeline
|
||||
)
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
elif is_group_offload:
|
||||
for component in _pipeline.components.values():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
_maybe_remove_and_reapply_group_offloading(component)
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
@@ -424,27 +431,45 @@ def _load_lora_into_text_encoder(
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
is_group_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
if not isinstance(component, nn.Module):
|
||||
continue
|
||||
is_group_offload = is_group_offload or _is_group_offload_enabled(component)
|
||||
if not hasattr(component, "_hf_hook"):
|
||||
continue
|
||||
is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
|
||||
is_sequential_cpu_offload = is_sequential_cpu_offload or (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
if is_sequential_cpu_offload or is_model_cpu_offload:
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
for _, component in _pipeline.components.items():
|
||||
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
|
||||
continue
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
@@ -453,6 +478,24 @@ class LoraBaseMixin:
|
||||
_lora_loadable_modules = []
|
||||
_merged_adapters = set()
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
"""
|
||||
Returns the lora scale which can be set at run time by the pipeline. # if `_lora_scale` has not been set,
|
||||
return 1.
|
||||
"""
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
@property
|
||||
def num_fused_loras(self):
|
||||
"""Returns the number of LoRAs that have been fused."""
|
||||
return len(self._merged_adapters)
|
||||
|
||||
@property
|
||||
def fused_loras(self):
|
||||
"""Returns names of the LoRAs that have been fused."""
|
||||
return self._merged_adapters
|
||||
|
||||
def load_lora_weights(self, **kwargs):
|
||||
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
||||
|
||||
@@ -464,33 +507,6 @@ class LoraBaseMixin:
|
||||
def lora_state_dict(cls, **kwargs):
|
||||
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
||||
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
||||
return _fetch_state_dict(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
||||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
||||
return _best_guess_weight_name(*args, **kwargs)
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
Unloads the LoRA parameters.
|
||||
@@ -661,19 +677,37 @@ class LoraBaseMixin:
|
||||
self._merged_adapters = self._merged_adapters - {adapter}
|
||||
module.unmerge()
|
||||
|
||||
@property
|
||||
def num_fused_loras(self):
|
||||
return len(self._merged_adapters)
|
||||
|
||||
@property
|
||||
def fused_loras(self):
|
||||
return self._merged_adapters
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
):
|
||||
"""
|
||||
Set the currently active adapters for use in the pipeline.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
adapter_weights (`Union[List[float], float]`, *optional*):
|
||||
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
|
||||
adapters.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
|
||||
```
|
||||
"""
|
||||
if isinstance(adapter_weights, dict):
|
||||
components_passed = set(adapter_weights.keys())
|
||||
lora_components = set(self._lora_loadable_modules)
|
||||
@@ -743,6 +777,24 @@ class LoraBaseMixin:
|
||||
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
|
||||
|
||||
def disable_lora(self):
|
||||
"""
|
||||
Disables the active LoRA layers of the pipeline.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.disable_lora()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
@@ -755,6 +807,24 @@ class LoraBaseMixin:
|
||||
disable_lora_for_text_encoder(model)
|
||||
|
||||
def enable_lora(self):
|
||||
"""
|
||||
Enables the active LoRA layers of the pipeline.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.enable_lora()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
@@ -768,10 +838,26 @@ class LoraBaseMixin:
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
"""
|
||||
Delete an adapter's LoRA layers from the pipeline.
|
||||
|
||||
Args:
|
||||
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
The names of the adapter to delete. Can be a single string or a list of strings
|
||||
The names of the adapters to delete.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
|
||||
)
|
||||
pipeline.delete_adapters("cinematic")
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -872,6 +958,24 @@ class LoraBaseMixin:
|
||||
adapter_name
|
||||
].to(device)
|
||||
|
||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||
"""
|
||||
Hotswap adapters without triggering recompilation of a model or if the ranks of the loaded adapters are
|
||||
different.
|
||||
|
||||
Args:
|
||||
target_rank (`int`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle a model that is already compiled. The check can return the following messages:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
for key, component in self.components.items():
|
||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||
component.enable_lora_hotswap(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
@@ -887,6 +991,7 @@ class LoraBaseMixin:
|
||||
safe_serialization: bool,
|
||||
lora_adapter_metadata: Optional[dict] = None,
|
||||
):
|
||||
"""Writes the state dict of the LoRA layers (optionally with metadata) to disk."""
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
@@ -927,28 +1032,6 @@ class LoraBaseMixin:
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
# if _lora_scale has not been set, return 1
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
for key, component in self.components.items():
|
||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||
component.enable_lora_hotswap(**kwargs)
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
||||
from ..utils import (
|
||||
MIN_PEFT_VERSION,
|
||||
USE_PEFT_BACKEND,
|
||||
@@ -85,17 +86,6 @@ class PeftAdapterMixin:
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def load_lora_adapter(
|
||||
@@ -254,20 +244,29 @@ class PeftAdapterMixin:
|
||||
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
# create LoraConfig
|
||||
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(self)
|
||||
|
||||
# create LoraConfig
|
||||
lora_config = _create_lora_config(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
metadata,
|
||||
rank,
|
||||
model_state_dict=self.state_dict(),
|
||||
adapter_name=adapter_name,
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
|
||||
# Now we remove any existing hooks to `_pipeline`.
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error.
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
||||
_pipeline
|
||||
)
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
@@ -358,6 +357,10 @@ class PeftAdapterMixin:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
elif is_group_offload:
|
||||
for component in _pipeline.components.values():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
_maybe_remove_and_reapply_group_offloading(component)
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
@@ -444,7 +447,7 @@ class PeftAdapterMixin:
|
||||
weights: Optional[Union[float, Dict, List[float], List[Dict], List[None]]] = None,
|
||||
):
|
||||
"""
|
||||
Set the currently active adapters for use in the UNet.
|
||||
Set the currently active adapters for use in the diffusion network (e.g. unet, transformer, etc.).
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
@@ -466,7 +469,7 @@ class PeftAdapterMixin:
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
|
||||
pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
@@ -697,6 +700,10 @@ class PeftAdapterMixin:
|
||||
recurse_remove_peft_layers(self)
|
||||
if hasattr(self, "peft_config"):
|
||||
del self.peft_config
|
||||
if hasattr(self, "_hf_peft_config_loaded"):
|
||||
self._hf_peft_config_loaded = None
|
||||
|
||||
_maybe_remove_and_reapply_group_offloading(self)
|
||||
|
||||
def disable_lora(self):
|
||||
"""
|
||||
@@ -714,7 +721,7 @@ class PeftAdapterMixin:
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.disable_lora()
|
||||
pipeline.unet.disable_lora()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
@@ -737,7 +744,7 @@ class PeftAdapterMixin:
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
|
||||
)
|
||||
pipeline.enable_lora()
|
||||
pipeline.unet.enable_lora()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
@@ -764,7 +771,7 @@ class PeftAdapterMixin:
|
||||
pipeline.load_lora_weights(
|
||||
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
|
||||
)
|
||||
pipeline.delete_adapters("cinematic")
|
||||
pipeline.unet.delete_adapters("cinematic")
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
|
||||
@@ -31,6 +31,7 @@ from .single_file_utils import (
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_chroma_transformer_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hidream_transformer_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
@@ -143,6 +144,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"CosmosTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -127,6 +127,16 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
||||
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
||||
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
|
||||
"cosmos-1.0": [
|
||||
"net.x_embedder.proj.1.weight",
|
||||
"net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
|
||||
"net.extra_pos_embedder.pos_emb_h",
|
||||
],
|
||||
"cosmos-2.0": [
|
||||
"net.x_embedder.proj.1.weight",
|
||||
"net.blocks.0.self_attn.q_proj.weight",
|
||||
"net.pos_embedder.dim_spatial_range",
|
||||
],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -193,6 +203,14 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
||||
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
||||
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
|
||||
"cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
|
||||
"cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
|
||||
"cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
|
||||
"cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
|
||||
"cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
|
||||
"cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
|
||||
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
|
||||
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -704,11 +722,32 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "wan-t2v-14B"
|
||||
else:
|
||||
model_type = "wan-i2v-14B"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
|
||||
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
|
||||
model_type = "wan-t2v-14B"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
|
||||
model_type = "hidream"
|
||||
|
||||
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
|
||||
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
|
||||
if x_embedder_shape[1] == 68:
|
||||
model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
|
||||
elif x_embedder_shape[1] == 72:
|
||||
model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
|
||||
else:
|
||||
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
|
||||
|
||||
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
|
||||
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
|
||||
if x_embedder_shape[1] == 68:
|
||||
model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
|
||||
elif x_embedder_shape[1] == 72:
|
||||
model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
|
||||
else:
|
||||
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -3479,3 +3518,116 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
def remove_keys_(key: str, state_dict):
|
||||
state_dict.pop(key)
|
||||
|
||||
def rename_transformer_blocks_(key: str, state_dict):
|
||||
block_index = int(key.split(".")[1].removeprefix("block"))
|
||||
new_key = key
|
||||
old_prefix = f"blocks.block{block_index}"
|
||||
new_prefix = f"transformer_blocks.{block_index}"
|
||||
new_key = new_prefix + new_key.removeprefix(old_prefix)
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"affline_norm": "time_embed.norm",
|
||||
".blocks.0.block.attn": ".attn1",
|
||||
".blocks.1.block.attn": ".attn2",
|
||||
".blocks.2.block": ".ff",
|
||||
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
|
||||
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
|
||||
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
|
||||
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
|
||||
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
|
||||
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
|
||||
"to_q.0": "to_q",
|
||||
"to_q.1": "norm_q",
|
||||
"to_k.0": "to_k",
|
||||
"to_k.1": "norm_k",
|
||||
"to_v.0": "to_v",
|
||||
"layer1": "net.0.proj",
|
||||
"layer2": "net.2",
|
||||
"proj.1": "proj",
|
||||
"x_embedder": "patch_embed",
|
||||
"extra_pos_embedder": "learnable_pos_embed",
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
|
||||
"blocks.block": rename_transformer_blocks_,
|
||||
"logvar.0.freqs": remove_keys_,
|
||||
"logvar.0.phases": remove_keys_,
|
||||
"logvar.1.weight": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
}
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"t_embedding_norm": "time_embed.norm",
|
||||
"blocks": "transformer_blocks",
|
||||
"adaln_modulation_self_attn.1": "norm1.linear_1",
|
||||
"adaln_modulation_self_attn.2": "norm1.linear_2",
|
||||
"adaln_modulation_cross_attn.1": "norm2.linear_1",
|
||||
"adaln_modulation_cross_attn.2": "norm2.linear_2",
|
||||
"adaln_modulation_mlp.1": "norm3.linear_1",
|
||||
"adaln_modulation_mlp.2": "norm3.linear_2",
|
||||
"self_attn": "attn1",
|
||||
"cross_attn": "attn2",
|
||||
"q_proj": "to_q",
|
||||
"k_proj": "to_k",
|
||||
"v_proj": "to_v",
|
||||
"output_proj": "to_out.0",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
"mlp.layer1": "ff.net.0.proj",
|
||||
"mlp.layer2": "ff.net.2",
|
||||
"x_embedder.proj.1": "patch_embed.proj",
|
||||
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
|
||||
"accum_video_sample_counter": remove_keys_,
|
||||
"accum_image_sample_counter": remove_keys_,
|
||||
"accum_iteration": remove_keys_,
|
||||
"accum_train_in_hours": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
"pos_embedder.dim_spatial_range": remove_keys_,
|
||||
"pos_embedder.dim_temporal_range": remove_keys_,
|
||||
"_extra_state": remove_keys_,
|
||||
}
|
||||
|
||||
PREFIX_KEY = "net."
|
||||
if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
|
||||
else:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||
|
||||
state_dict_keys = list(converted_state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
state_dict_keys = list(converted_state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -427,7 +427,8 @@ class TextualInversionLoaderMixin:
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
if is_sequential_cpu_offload or is_model_cpu_offload:
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
# 7.2 save expected device and dtype
|
||||
device = text_encoder.device
|
||||
|
||||
@@ -22,6 +22,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
IPAdapterFaceIDImageProjection,
|
||||
@@ -203,6 +204,7 @@ class UNet2DConditionLoadersMixin:
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
is_group_offload = False
|
||||
|
||||
if is_lora:
|
||||
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
|
||||
@@ -211,7 +213,7 @@ class UNet2DConditionLoadersMixin:
|
||||
if is_custom_diffusion:
|
||||
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
|
||||
elif is_lora:
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
|
||||
state_dict=state_dict,
|
||||
unet_identifier_key=self.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
@@ -230,7 +232,9 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
|
||||
if is_custom_diffusion and _pipeline is not None:
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
||||
_pipeline=_pipeline
|
||||
)
|
||||
|
||||
# only custom diffusion needs to set attn processors
|
||||
self.set_attn_processor(attn_processors)
|
||||
@@ -241,6 +245,10 @@ class UNet2DConditionLoadersMixin:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
elif is_group_offload:
|
||||
for component in _pipeline.components.values():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
_maybe_remove_and_reapply_group_offloading(component)
|
||||
# Unsafe code />
|
||||
|
||||
def _process_custom_diffusion(self, state_dict):
|
||||
@@ -307,6 +315,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
is_group_offload = False
|
||||
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
|
||||
|
||||
if len(state_dict_to_be_used) > 0:
|
||||
@@ -356,7 +365,9 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
||||
_pipeline
|
||||
)
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
@@ -389,22 +400,11 @@ class UNet2DConditionLoadersMixin:
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
def save_attn_procs(
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Dict, List, Union
|
||||
|
||||
from torch import nn
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
|
||||
weight_for_adapter,
|
||||
blocks_with_transformer,
|
||||
transformer_per_block,
|
||||
unet.state_dict(),
|
||||
model=unet,
|
||||
default_scale=default_scale,
|
||||
)
|
||||
for weight_for_adapter in weight_scales
|
||||
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
|
||||
scales: Union[float, Dict],
|
||||
blocks_with_transformer: Dict[str, int],
|
||||
transformer_per_block: Dict[str, int],
|
||||
state_dict: None,
|
||||
model: nn.Module,
|
||||
default_scale: float = 1.0,
|
||||
):
|
||||
"""
|
||||
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
|
||||
|
||||
del scales[updown]
|
||||
|
||||
state_dict = model.state_dict()
|
||||
for layer in scales.keys():
|
||||
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
|
||||
raise ValueError(
|
||||
|
||||
@@ -110,8 +110,11 @@ class CosmosPatchEmbed3d(nn.Module):
|
||||
self.patch_size = patch_size
|
||||
self.patch_method = patch_method
|
||||
|
||||
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
|
||||
self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
|
||||
wavelets = _WAVELETS.get(patch_method).clone()
|
||||
arange = torch.arange(wavelets.shape[0])
|
||||
|
||||
self.register_buffer("wavelets", wavelets, persistent=False)
|
||||
self.register_buffer("_arange", arange, persistent=False)
|
||||
|
||||
def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
|
||||
dtype = hidden_states.dtype
|
||||
@@ -185,12 +188,11 @@ class CosmosUnpatcher3d(nn.Module):
|
||||
self.patch_size = patch_size
|
||||
self.patch_method = patch_method
|
||||
|
||||
self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
|
||||
self.register_buffer(
|
||||
"_arange",
|
||||
torch.arange(_WAVELETS[patch_method].shape[0]),
|
||||
persistent=False,
|
||||
)
|
||||
wavelets = _WAVELETS.get(patch_method).clone()
|
||||
arange = torch.arange(wavelets.shape[0])
|
||||
|
||||
self.register_buffer("wavelets", wavelets, persistent=False)
|
||||
self.register_buffer("_arange", arange, persistent=False)
|
||||
|
||||
def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
|
||||
device = hidden_states.device
|
||||
|
||||
@@ -1199,11 +1199,11 @@ def apply_rotary_emb(
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Used for flux, cogvideox, hunyuan-dit
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
_keep_in_fp32_modules = None
|
||||
_skip_layerwise_casting_patterns = None
|
||||
_supports_group_offloading = True
|
||||
_repeated_blocks = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -1404,6 +1405,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
return super().float(*args)
|
||||
|
||||
def compile_repeated_blocks(self, *args, **kwargs):
|
||||
"""
|
||||
Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
|
||||
compiling the entire model. This technique—often called **regional compilation** (see the PyTorch recipe
|
||||
https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) can reduce end-to-end compile time
|
||||
substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`.
|
||||
|
||||
The set of sub-modules to compile is discovered by the presence of **`_repeated_blocks`** attribute in the
|
||||
model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every
|
||||
module whose class name matches will be compiled.
|
||||
|
||||
Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
|
||||
positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
|
||||
`torch.compile`.
|
||||
"""
|
||||
repeated_blocks = getattr(self, "_repeated_blocks", None)
|
||||
|
||||
if not repeated_blocks:
|
||||
raise ValueError(
|
||||
"`_repeated_blocks` attribute is empty. "
|
||||
f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
|
||||
)
|
||||
has_compiled_region = False
|
||||
for submod in self.modules():
|
||||
if submod.__class__.__name__ in repeated_blocks:
|
||||
submod.compile(*args, **kwargs)
|
||||
has_compiled_region = True
|
||||
|
||||
if not has_compiled_region:
|
||||
raise ValueError(
|
||||
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
|
||||
@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
|
||||
_repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
|
||||
@register_to_config
|
||||
|
||||
@@ -20,6 +20,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin
|
||||
from ...utils import is_torchvision_available
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
@@ -377,7 +378,7 @@ class CosmosLearnablePositionalEmbed(nn.Module):
|
||||
return (emb / norm).type_as(hidden_states)
|
||||
|
||||
|
||||
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
|
||||
|
||||
|
||||
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
"HunyuanVideoPatchEmbed",
|
||||
"HunyuanVideoTokenRefiner",
|
||||
]
|
||||
_repeated_blocks = [
|
||||
"HunyuanVideoTransformerBlock",
|
||||
"HunyuanVideoSingleTransformerBlock",
|
||||
"HunyuanVideoPatchEmbed",
|
||||
"HunyuanVideoTokenRefiner",
|
||||
]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
_repeated_blocks = ["LTXVideoTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
@@ -481,7 +482,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
|
||||
|
||||
def apply_rotary_emb(x, freqs):
|
||||
cos, sin = freqs
|
||||
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, H, D // 2]
|
||||
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
|
||||
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
|
||||
_no_split_modules = ["WanTransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["WanTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -167,6 +167,7 @@ class UNet2DConditionModel(
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
|
||||
_skip_layerwise_casting_patterns = ["norm"]
|
||||
_repeated_blocks = ["BasicTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -140,6 +140,7 @@ else:
|
||||
"FluxFillPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"ReduxImageEncoder",
|
||||
"FluxKontextPipeline",
|
||||
]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
@@ -609,6 +610,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxFillPipeline,
|
||||
FluxImg2ImgPipeline,
|
||||
FluxInpaintPipeline,
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
|
||||
@@ -41,7 +41,7 @@ from ...utils import (
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.import_utils import is_transformers_version
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...utils.torch_utils import empty_device_cache, randn_tensor
|
||||
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
|
||||
from .modeling_audioldm2 import AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
|
||||
|
||||
@@ -267,9 +267,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
|
||||
if self.device.type != "cpu":
|
||||
self.to("cpu", silence_dtype_warnings=True)
|
||||
device_mod = getattr(torch, device.type, None)
|
||||
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
|
||||
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
||||
empty_device_cache(device.type)
|
||||
|
||||
model_sequence = [
|
||||
self.text_encoder.text_model,
|
||||
|
||||
@@ -52,20 +52,21 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaPipeline
|
||||
|
||||
>>> model_id = "lodestones/Chroma"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
|
||||
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.1-schnell",
|
||||
>>> pipe = ChromaPipeline.from_pretrained(
|
||||
... model_id,
|
||||
... transformer=transformer,
|
||||
... text_encoder=text_encoder,
|
||||
... tokenizer=tokenizer,
|
||||
... torch_dtype=torch.bfloat16,
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||
>>> prompt = [
|
||||
... "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
|
||||
... ]
|
||||
>>> negative_prompt = [
|
||||
... "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||
... ]
|
||||
>>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
|
||||
>>> image.save("chroma.png")
|
||||
```
|
||||
|
||||
@@ -51,26 +51,21 @@ EXAMPLE_DOC_STRING = """
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
|
||||
>>> from transformers import AutoModel, Autotokenizer
|
||||
|
||||
>>> model_id = "lodestones/Chroma"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
|
||||
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.1-schnell",
|
||||
... model_id,
|
||||
... transformer=transformer,
|
||||
... text_encoder=text_encoder,
|
||||
... tokenizer=tokenizer,
|
||||
... torch_dtype=torch.bfloat16,
|
||||
... )
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> image = load_image(
|
||||
>>> init_image = load_image(
|
||||
... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
... )
|
||||
>>> prompt = "a scenic fastasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
|
||||
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
|
||||
>>> image = pipe(prompt, image=image, negative_prompt=negative_prompt).images[0]
|
||||
>>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
|
||||
>>> image.save("chroma-img2img.png")
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -294,7 +294,7 @@ def prepare_face_models(model_path, device, dtype):
|
||||
|
||||
Parameters:
|
||||
- model_path: Path to the directory containing model files.
|
||||
- device: The device (e.g., 'cuda', 'cpu') where models will be loaded.
|
||||
- device: The device (e.g., 'cuda', 'xpu', 'cpu') where models will be loaded.
|
||||
- dtype: Data type (e.g., torch.float32) for model inference.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -37,7 +37,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
|
||||
from ...utils.torch_utils import empty_device_cache, is_compiled_module, is_torch_version, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -1339,7 +1339,7 @@ class StableDiffusionControlNetPipeline(
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
empty_device_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
|
||||
@@ -36,7 +36,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -1311,7 +1311,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
empty_device_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
|
||||
@@ -38,7 +38,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -1500,7 +1500,7 @@ class StableDiffusionControlNetInpaintPipeline(
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
empty_device_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
|
||||
@@ -51,7 +51,7 @@ from ...utils import (
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import is_compiled_module, randn_tensor
|
||||
from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
|
||||
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
|
||||
@@ -1858,7 +1858,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
empty_device_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
|
||||
@@ -1465,7 +1465,11 @@ class StableDiffusionXLControlNetPipeline(
|
||||
|
||||
# Relevant thread:
|
||||
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
|
||||
if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
|
||||
if (
|
||||
torch.cuda.is_available()
|
||||
and (is_unet_compiled and is_controlnet_compiled)
|
||||
and is_torch_higher_equal_2_1
|
||||
):
|
||||
torch._inductor.cudagraph_mark_step_begin()
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user