Compare commits
85 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1cd5155bb8 | |||
| b14bffeffe | |||
| e66c4d0dab | |||
| 7d2c7d5553 | |||
| 78135f1478 | |||
| 9ff72433fa | |||
| c1926cef6b | |||
| 8421c1461b | |||
| cfdeebd4a8 | |||
| 6a51427b6a | |||
| 5effcd3e64 | |||
| 619b9658e2 | |||
| b58f67f2d5 | |||
| 8ac6de963c | |||
| 2be66e6aa0 | |||
| cf258948b2 | |||
| d8408677c5 | |||
| 63b631f383 | |||
| acf79b3487 | |||
| fc72e0f261 | |||
| 0763a7edf4 | |||
| 963ffca434 | |||
| 30f2e9bd20 | |||
| 2312b27f79 | |||
| 6db33337a4 | |||
| beb856685d | |||
| a9d3f6c359 | |||
| cd344393e2 | |||
| c44fba8899 | |||
| 922c5f5c3c | |||
| 8d386f7990 | |||
| 827b6c25f9 | |||
| 784b351f32 | |||
| cbee7cbc6b | |||
| c96bfa5c80 | |||
| 6b288ec44d | |||
| fdec8bd675 | |||
| 2eeda25321 | |||
| 069186fac5 | |||
| 69c83d6eed | |||
| e44fc75acb | |||
| e47cc1fc1a | |||
| 75bd1e83cb | |||
| 0389333113 | |||
| 1fb86e34c0 | |||
| 8d477daed5 | |||
| ad5ecd1251 | |||
| 074e12358b | |||
| 047bf49291 | |||
| c4b5d2ff6b | |||
| 7ac6e286ee | |||
| b5fd6f13f5 | |||
| 64b3e0f539 | |||
| 2e86a3f023 | |||
| cd6ca9df29 | |||
| e564abe292 | |||
| 3139d39fa7 | |||
| 12358622e5 | |||
| 805aa93789 | |||
| f6f7afa1d7 | |||
| 637e2302ac | |||
| 99c0483b67 | |||
| cc7d88f247 | |||
| ea40933f36 | |||
| 0583a8d12a | |||
| 7d0b9c4d4e | |||
| acf479bded | |||
| 03bf77c4af | |||
| 3b2830618d | |||
| c3c94fe71b | |||
| 365a938884 | |||
| 345907f32d | |||
| 07d0fbf3ec | |||
| 1d2204d3a0 | |||
| d38c50c8dd | |||
| e255920719 | |||
| 40ab1c03f3 | |||
| 5c94937dc7 | |||
| d74483c47a | |||
| 1dbd26fa23 | |||
| dac623b59f | |||
| 8d6dc2be5d | |||
| d720b2132e | |||
| 9cc96a64f1 | |||
| 5b972fbd6a |
@@ -347,6 +347,64 @@ jobs:
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
run_nightly_quantization_tests:
|
||||
name: Torch quantization nightly tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
matrix:
|
||||
config:
|
||||
- backend: "bitsandbytes"
|
||||
test_location: "bnb"
|
||||
runs-on:
|
||||
group: aws-g6e-xlarge-plus
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "20gb" --ipc host --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
run: nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install -U ${{ matrix.config.backend }}
|
||||
python -m uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: ${{ matrix.config.backend }} quantization tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
# https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms
|
||||
CUBLAS_WORKSPACE_CONFIG: :16:8
|
||||
BIG_GPU_MEMORY: 40
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
--make-reports=tests_${{ matrix.config.backend }}_torch_cuda \
|
||||
--report-log=tests_${{ matrix.config.backend }}_torch_cuda.log \
|
||||
tests/quantization/${{ matrix.config.test_location }}
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_stats.txt
|
||||
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_${{ matrix.config.backend }}_reports
|
||||
path: reports
|
||||
- name: Generate Report and Notify Channel
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python utils/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
# M1 runner currently not well supported
|
||||
# TODO: (Dhruv) add these back when we setup better testing for Apple Silicon
|
||||
# run_nightly_tests_apple_m1:
|
||||
|
||||
@@ -114,7 +114,7 @@ Check out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to l
|
||||
| [Tutorial](https://huggingface.co/docs/diffusers/tutorials/tutorial_overview) | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model. |
|
||||
| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading_overview) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. |
|
||||
| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/pipeline_overview) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. |
|
||||
| [Optimization](https://huggingface.co/docs/diffusers/optimization/opt_overview) | Guides for how to optimize your diffusion model to run faster and consume less memory. |
|
||||
| [Optimization](https://huggingface.co/docs/diffusers/optimization/fp16) | Guides for how to optimize your diffusion model to run faster and consume less memory. |
|
||||
| [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques. |
|
||||
## Contribution
|
||||
|
||||
|
||||
@@ -28,7 +28,7 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
"torch<2.5.0" \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
"onnxruntime-gpu>=1.13.1" \
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
"torch<2.5.0" \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
"torch<2.5.0" \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark \
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m uv pip install --no-cache-dir \
|
||||
"torch<2.5.0" \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
|
||||
@@ -29,7 +29,7 @@ ENV PATH="/opt/venv/bin:$PATH"
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
|
||||
python3.10 -m pip install --no-cache-dir \
|
||||
"torch<2.5.0" \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
invisible_watermark && \
|
||||
|
||||
@@ -55,6 +55,8 @@
|
||||
- sections:
|
||||
- local: using-diffusers/overview_techniques
|
||||
title: Overview
|
||||
- local: using-diffusers/create_a_server
|
||||
title: Create a server
|
||||
- local: training/distributed_inference
|
||||
title: Distributed inference
|
||||
- local: using-diffusers/merge_loras
|
||||
|
||||
@@ -29,16 +29,32 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
|
||||
|
||||
There are two models available that can be used with the text-to-video and video-to-video CogVideoX pipelines:
|
||||
- [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b): The recommended dtype for running this model is `fp16`.
|
||||
- [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b): The recommended dtype for running this model is `bf16`.
|
||||
There are three official CogVideoX checkpoints for text-to-video and video-to-video.
|
||||
|
||||
There is one model available that can be used with the image-to-video CogVideoX pipeline:
|
||||
- [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V): The recommended dtype for running this model is `bf16`.
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`THUDM/CogVideoX-2b`](https://huggingface.co/THUDM/CogVideoX-2b) | torch.float16 |
|
||||
| [`THUDM/CogVideoX-5b`](https://huggingface.co/THUDM/CogVideoX-5b) | torch.bfloat16 |
|
||||
| [`THUDM/CogVideoX1.5-5b`](https://huggingface.co/THUDM/CogVideoX1.5-5b) | torch.bfloat16 |
|
||||
|
||||
There are two models that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team):
|
||||
- [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose): The recommended dtype for running this model is `bf16`.
|
||||
- [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose): The recommended dtype for running this model is `bf16`.
|
||||
There are two official CogVideoX checkpoints available for image-to-video.
|
||||
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`THUDM/CogVideoX-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-5b-I2V) | torch.bfloat16 |
|
||||
| [`THUDM/CogVideoX-1.5-5b-I2V`](https://huggingface.co/THUDM/CogVideoX-1.5-5b-I2V) | torch.bfloat16 |
|
||||
|
||||
For the CogVideoX 1.5 series:
|
||||
- Text-to-video (T2V) works best at a resolution of 1360x768 because it was trained with that specific resolution.
|
||||
- Image-to-video (I2V) works for multiple resolutions. The width can vary from 768 to 1360, but the height must be 768. The height/width must be divisible by 16.
|
||||
- Both T2V and I2V models support generation with 81 and 161 frames and work best at this value. Exporting videos at 16 FPS is recommended.
|
||||
|
||||
There are two official CogVideoX checkpoints that support pose controllable generation (by the [Alibaba-PAI](https://huggingface.co/alibaba-pai) team).
|
||||
|
||||
| checkpoints | recommended inference dtype |
|
||||
|:---:|:---:|
|
||||
| [`alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-2b-Pose) | torch.bfloat16 |
|
||||
| [`alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose`](https://huggingface.co/alibaba-pai/CogVideoX-Fun-V1.1-5b-Pose) | torch.bfloat16 |
|
||||
|
||||
## Inference
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ This controlnet code is mainly implemented by [The InstantX Team](https://huggin
|
||||
| ControlNet type | Developer | Link |
|
||||
| -------- | ---------- | ---- |
|
||||
| Canny | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Canny) |
|
||||
| Depth | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Depth) |
|
||||
| Pose | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Pose) |
|
||||
| Tile | [The InstantX Team](https://huggingface.co/InstantX) | [Link](https://huggingface.co/InstantX/SD3-Controlnet-Tile) |
|
||||
| Inpainting | [The AlimamaCreative Team](https://huggingface.co/alimama-creative) | [link](https://huggingface.co/alimama-creative/SD3-Controlnet-Inpainting) |
|
||||
|
||||
@@ -22,12 +22,20 @@ Flux can be quite expensive to run on consumer hardware devices. However, you ca
|
||||
|
||||
</Tip>
|
||||
|
||||
Flux comes in two variants:
|
||||
Flux comes in the following variants:
|
||||
|
||||
* Timestep-distilled (`black-forest-labs/FLUX.1-schnell`)
|
||||
* Guidance-distilled (`black-forest-labs/FLUX.1-dev`)
|
||||
| model type | model id |
|
||||
|:----------:|:--------:|
|
||||
| Timestep-distilled | [`black-forest-labs/FLUX.1-schnell`](https://huggingface.co/black-forest-labs/FLUX.1-schnell) |
|
||||
| Guidance-distilled | [`black-forest-labs/FLUX.1-dev`](https://huggingface.co/black-forest-labs/FLUX.1-dev) |
|
||||
| Fill Inpainting/Outpainting (Guidance-distilled) | [`black-forest-labs/FLUX.1-Fill-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev) |
|
||||
| Canny Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Canny-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev) |
|
||||
| Depth Control (Guidance-distilled) | [`black-forest-labs/FLUX.1-Depth-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev) |
|
||||
| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
|
||||
| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
|
||||
| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
|
||||
|
||||
Both checkpoints have slightly difference usage which we detail below.
|
||||
All checkpoints have different usage which we detail below.
|
||||
|
||||
### Timestep-distilled
|
||||
|
||||
@@ -77,7 +85,132 @@ out = pipe(
|
||||
out.save("image.png")
|
||||
```
|
||||
|
||||
### Fill Inpainting/Outpainting
|
||||
|
||||
* Flux Fill pipeline does not require `strength` as an input like regular inpainting pipelines.
|
||||
* It supports both inpainting and outpainting.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxFillPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup.png")
|
||||
mask = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/cup_mask.png")
|
||||
|
||||
repo_id = "black-forest-labs/FLUX.1-Fill-dev"
|
||||
pipe = FluxFillPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to("cuda")
|
||||
|
||||
image = pipe(
|
||||
prompt="a white paper cup",
|
||||
image=image,
|
||||
mask_image=mask,
|
||||
height=1632,
|
||||
width=1232,
|
||||
max_sequence_length=512,
|
||||
generator=torch.Generator("cpu").manual_seed(0)
|
||||
).images[0]
|
||||
image.save(f"output.png")
|
||||
```
|
||||
|
||||
### Canny Control
|
||||
|
||||
**Note:** `black-forest-labs/Flux.1-Canny-dev` is _not_ a [`ControlNetModel`] model. ControlNet models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Canny Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible.
|
||||
|
||||
```python
|
||||
# !pip install -U controlnet-aux
|
||||
import torch
|
||||
from controlnet_aux import CannyDetector
|
||||
from diffusers import FluxControlPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Canny-dev", torch_dtype=torch.bfloat16).to("cuda")
|
||||
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
|
||||
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
processor = CannyDetector()
|
||||
control_image = processor(control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024)
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=30.0,
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
### Depth Control
|
||||
|
||||
**Note:** `black-forest-labs/Flux.1-Depth-dev` is _not_ a ControlNet model. [`ControlNetModel`] models are a separate component from the UNet/Transformer whose residuals are added to the actual underlying model. Depth Control is an alternate architecture that achieves effectively the same results as a ControlNet model would, by using channel-wise concatenation with input control condition and ensuring the transformer learns structure control by following the condition as closely as possible.
|
||||
|
||||
```python
|
||||
# !pip install git+https://github.com/huggingface/image_gen_aux
|
||||
import torch
|
||||
from diffusers import FluxControlPipeline, FluxTransformer2DModel
|
||||
from diffusers.utils import load_image
|
||||
from image_gen_aux import DepthPreprocessor
|
||||
|
||||
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-Depth-dev", torch_dtype=torch.bfloat16).to("cuda")
|
||||
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds. The background is filled with confetti and celebratory gifts."
|
||||
control_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/robot.png")
|
||||
|
||||
processor = DepthPreprocessor.from_pretrained("LiheYoung/depth-anything-large-hf")
|
||||
control_image = processor(control_image)[0].convert("RGB")
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=10.0,
|
||||
generator=torch.Generator().manual_seed(42),
|
||||
).images[0]
|
||||
image.save("output.png")
|
||||
```
|
||||
|
||||
### Redux
|
||||
|
||||
* Flux Redux pipeline is an adapter for FLUX.1 base models. It can be used with both flux-dev and flux-schnell, for image-to-image generation.
|
||||
* You can first use the `FluxPriorReduxPipeline` to get the `prompt_embeds` and `pooled_prompt_embeds`, and then feed them into the `FluxPipeline` for image-to-image generation.
|
||||
* When use `FluxPriorReduxPipeline` with a base pipeline, you can set `text_encoder=None` and `text_encoder_2=None` in the base pipeline, in order to save VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPriorReduxPipeline, FluxPipeline
|
||||
from diffusers.utils import load_image
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
|
||||
repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
|
||||
repo_base = "black-forest-labs/FLUX.1-dev"
|
||||
pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
repo_base,
|
||||
text_encoder=None,
|
||||
text_encoder_2=None,
|
||||
torch_dtype=torch.bfloat16
|
||||
).to(device)
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png")
|
||||
pipe_prior_output = pipe_prior_redux(image)
|
||||
images = pipe(
|
||||
guidance_scale=2.5,
|
||||
num_inference_steps=50,
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
**pipe_prior_output,
|
||||
).images
|
||||
images[0].save("flux-redux.png")
|
||||
```
|
||||
|
||||
## Running FP16 inference
|
||||
|
||||
Flux can generate high-quality images with FP16 (i.e. to accelerate inference on Turing/Volta GPUs) but produces different outputs compared to FP32/BF16. The issue is that some activations in the text encoders have to be clipped when running in FP16, which affects the overall image. Forcing text encoders to run with FP32 inference thus removes this output difference. See [here](https://github.com/huggingface/diffusers/pull/9097#issuecomment-2272292516) for details.
|
||||
|
||||
FP16 inference code:
|
||||
@@ -188,3 +321,27 @@ image.save("flux-fp8-dev.png")
|
||||
[[autodoc]] FluxControlNetImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## FluxControlPipeline
|
||||
|
||||
[[autodoc]] FluxControlPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## FluxControlImg2ImgPipeline
|
||||
|
||||
[[autodoc]] FluxControlImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## FluxPriorReduxPipeline
|
||||
|
||||
[[autodoc]] FluxPriorReduxPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## FluxFillPipeline
|
||||
|
||||
[[autodoc]] FluxFillPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -96,6 +96,10 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusion3PAGImg2ImgPipeline
|
||||
[[autodoc]] StableDiffusion3PAGImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## PixArtSigmaPAGPipeline
|
||||
[[autodoc]] PixArtSigmaPAGPipeline
|
||||
|
||||
@@ -181,7 +181,7 @@ Then we load the [v1-5 checkpoint](https://huggingface.co/stable-diffusion-v1-5/
|
||||
|
||||
```python
|
||||
model_ckpt_1_5 = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
||||
sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=weight_dtype).to(device)
|
||||
sd_pipeline_1_5 = StableDiffusionPipeline.from_pretrained(model_ckpt_1_5, torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
images_1_5 = sd_pipeline_1_5(prompts, num_images_per_prompt=1, generator=generator, output_type="np").images
|
||||
```
|
||||
@@ -280,7 +280,7 @@ from diffusers import StableDiffusionInstructPix2PixPipeline
|
||||
|
||||
instruct_pix2pix_pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
|
||||
"timbrooks/instruct-pix2pix", torch_dtype=torch.float16
|
||||
).to(device)
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Now, we perform the edits:
|
||||
@@ -326,9 +326,9 @@ from transformers import (
|
||||
|
||||
clip_id = "openai/clip-vit-large-patch14"
|
||||
tokenizer = CLIPTokenizer.from_pretrained(clip_id)
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to(device)
|
||||
text_encoder = CLIPTextModelWithProjection.from_pretrained(clip_id).to("cuda")
|
||||
image_processor = CLIPImageProcessor.from_pretrained(clip_id)
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to(device)
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(clip_id).to("cuda")
|
||||
```
|
||||
|
||||
Notice that we are using a particular CLIP checkpoint, i.e., `openai/clip-vit-large-patch14`. This is because the Stable Diffusion pre-training was performed with this CLIP variant. For more details, refer to the [documentation](https://huggingface.co/docs/transformers/model_doc/clip).
|
||||
@@ -350,7 +350,7 @@ class DirectionalSimilarity(nn.Module):
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = self.image_processor(image, return_tensors="pt")["pixel_values"]
|
||||
return {"pixel_values": image.to(device)}
|
||||
return {"pixel_values": image.to("cuda")}
|
||||
|
||||
def tokenize_text(self, text):
|
||||
inputs = self.tokenizer(
|
||||
@@ -360,7 +360,7 @@ class DirectionalSimilarity(nn.Module):
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return {"input_ids": inputs.input_ids.to(device)}
|
||||
return {"input_ids": inputs.input_ids.to("cuda")}
|
||||
|
||||
def encode_image(self, image):
|
||||
preprocessed_image = self.preprocess_image(image)
|
||||
@@ -459,6 +459,7 @@ with ZipFile(local_filepath, "r") as zipper:
|
||||
```python
|
||||
from PIL import Image
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
dataset_path = "sample-imagenet-images"
|
||||
image_paths = sorted([os.path.join(dataset_path, x) for x in os.listdir(dataset_path)])
|
||||
@@ -477,6 +478,7 @@ Now that the images are loaded, let's apply some lightweight pre-processing on t
|
||||
|
||||
```python
|
||||
from torchvision.transforms import functional as F
|
||||
import torch
|
||||
|
||||
|
||||
def preprocess_image(image):
|
||||
@@ -498,6 +500,10 @@ dit_pipeline = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=
|
||||
dit_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(dit_pipeline.scheduler.config)
|
||||
dit_pipeline = dit_pipeline.to("cuda")
|
||||
|
||||
seed = 0
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
|
||||
words = [
|
||||
"cassette player",
|
||||
"chainsaw",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Create a dataset for training
|
||||
|
||||
There are many datasets on the [Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) to train a model on, but if you can't find one you're interested in or want to use your own, you can create a dataset with the 🤗 [Datasets](hf.co/docs/datasets) library. The dataset structure depends on the task you want to train your model on. The most basic dataset structure is a directory of images for tasks like unconditional image generation. Another dataset structure may be a directory of images and a text file containing their corresponding text captions for tasks like text-to-image generation.
|
||||
There are many datasets on the [Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) to train a model on, but if you can't find one you're interested in or want to use your own, you can create a dataset with the 🤗 [Datasets](https://huggingface.co/docs/datasets) library. The dataset structure depends on the task you want to train your model on. The most basic dataset structure is a directory of images for tasks like unconditional image generation. Another dataset structure may be a directory of images and a text file containing their corresponding text captions for tasks like text-to-image generation.
|
||||
|
||||
This guide will show you two ways to create a dataset to finetune on:
|
||||
|
||||
@@ -87,4 +87,4 @@ accelerate launch --mixed_precision="fp16" train_text_to_image.py \
|
||||
|
||||
Now that you've created a dataset, you can plug it into the `train_data_dir` (if your dataset is local) or `dataset_name` (if your dataset is on the Hub) arguments of a training script.
|
||||
|
||||
For your next steps, feel free to try and use your dataset to train a model for [unconditional generation](unconditional_training) or [text-to-image generation](text2image)!
|
||||
For your next steps, feel free to try and use your dataset to train a model for [unconditional generation](unconditional_training) or [text-to-image generation](text2image)!
|
||||
|
||||
@@ -75,7 +75,7 @@ For convenience, create a `TrainingConfig` class containing the training hyperpa
|
||||
|
||||
... push_to_hub = True # whether to upload the saved model to the HF Hub
|
||||
... hub_model_id = "<your-username>/<my-awesome-model>" # the name of the repository to create on the HF Hub
|
||||
... hub_private_repo = False
|
||||
... hub_private_repo = None
|
||||
... overwrite_output_dir = True # overwrite the old model when re-running the notebook
|
||||
... seed = 0
|
||||
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
|
||||
# Create a server
|
||||
|
||||
Diffusers' pipelines can be used as an inference engine for a server. It supports concurrent and multithreaded requests to generate images that may be requested by multiple users at the same time.
|
||||
|
||||
This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.
|
||||
|
||||
|
||||
Start by navigating to the `examples/server` folder and installing all of the dependencies.
|
||||
|
||||
```py
|
||||
pip install .
|
||||
pip install -f requirements.txt
|
||||
```
|
||||
|
||||
Launch the server with the following command.
|
||||
|
||||
```py
|
||||
python server.py
|
||||
```
|
||||
|
||||
The server is accessed at http://localhost:8000. You can curl this model with the following command.
|
||||
```
|
||||
curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
|
||||
```
|
||||
|
||||
If you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.
|
||||
|
||||
```
|
||||
uv pip compile requirements.in -o requirements.txt
|
||||
```
|
||||
|
||||
|
||||
The server is built with [FastAPI](https://fastapi.tiangolo.com/async/). The endpoint for `v1/images/generations` is shown below.
|
||||
```py
|
||||
@app.post("/v1/images/generations")
|
||||
async def generate_image(image_input: TextToImageInput):
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
|
||||
pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
|
||||
generator = torch.Generator(device="cuda")
|
||||
generator.manual_seed(random.randint(0, 10000000))
|
||||
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
|
||||
logger.info(f"output: {output}")
|
||||
image_url = save_image(output.images[0])
|
||||
return {"data": [{"url": image_url}]}
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
elif hasattr(e, 'message'):
|
||||
raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
|
||||
```
|
||||
The `generate_image` function is defined as asynchronous with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows that whatever is happening in this function won't necessarily return a result right away. Once it hits some point in the function that it needs to await some other [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the main thread goes back to answering other HTTP requests. This is shown in the code below with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword.
|
||||
```py
|
||||
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
|
||||
```
|
||||
At this point, the execution of the pipeline function is placed onto a [new thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor), and the main thread performs other things until a result is returned from the `pipeline`.
|
||||
|
||||
Another important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal behind this is to avoid loading the underlying model more than once onto the GPU while still allowing for each new request that is running on a separate thread to have its own generator and scheduler. The scheduler, in particular, is not thread-safe, and it will cause errors like: `IndexError: index 21 is out of bounds for dimension 0 with size 21` if you try to use the same scheduler across multiple threads.
|
||||
@@ -121,7 +121,7 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inferen
|
||||
|
||||
### 이미지 결과물을 정제하기
|
||||
|
||||
[base 모델 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)에서, StableDiffusion-XL 또한 고주파 품질을 향상시키는 이미지를 생성하기 위해 낮은 노이즈 단계 이미지를 제거하는데 특화된 [refiner 체크포인트](huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 포함하고 있습니다. 이 refiner 체크포인트는 이미지 품질을 향상시키기 위해 base 체크포인트를 실행한 후 "두 번째 단계" 파이프라인에 사용될 수 있습니다.
|
||||
[base 모델 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)에서, StableDiffusion-XL 또한 고주파 품질을 향상시키는 이미지를 생성하기 위해 낮은 노이즈 단계 이미지를 제거하는데 특화된 [refiner 체크포인트](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 포함하고 있습니다. 이 refiner 체크포인트는 이미지 품질을 향상시키기 위해 base 체크포인트를 실행한 후 "두 번째 단계" 파이프라인에 사용될 수 있습니다.
|
||||
|
||||
refiner를 사용할 때, 쉽게 사용할 수 있습니다
|
||||
- 1.) base 모델과 refiner을 사용하는데, 이는 *Denoisers의 앙상블*을 위한 첫 번째 제안된 [eDiff-I](https://research.nvidia.com/labs/dir/eDiff-I/)를 사용하거나
|
||||
@@ -215,7 +215,7 @@ image = refiner(
|
||||
|
||||
#### 2.) 노이즈가 완전히 제거된 기본 이미지에서 이미지 출력을 정제하기
|
||||
|
||||
일반적인 [`StableDiffusionImg2ImgPipeline`] 방식에서, 기본 모델에서 생성된 완전히 노이즈가 제거된 이미지는 [refiner checkpoint](huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 사용해 더 향상시킬 수 있습니다.
|
||||
일반적인 [`StableDiffusionImg2ImgPipeline`] 방식에서, 기본 모델에서 생성된 완전히 노이즈가 제거된 이미지는 [refiner checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0)를 사용해 더 향상시킬 수 있습니다.
|
||||
|
||||
이를 위해, 보통의 "base" text-to-image 파이프라인을 수행 후에 image-to-image 파이프라인으로써 refiner를 실행시킬 수 있습니다. base 모델의 출력을 잠재 공간에 남겨둘 수 있습니다.
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# 학습을 위한 데이터셋 만들기
|
||||
|
||||
[Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) 에는 모델 교육을 위한 많은 데이터셋이 있지만,
|
||||
관심이 있거나 사용하고 싶은 데이터셋을 찾을 수 없는 경우 🤗 [Datasets](hf.co/docs/datasets) 라이브러리를 사용하여 데이터셋을 만들 수 있습니다.
|
||||
관심이 있거나 사용하고 싶은 데이터셋을 찾을 수 없는 경우 🤗 [Datasets](https://huggingface.co/docs/datasets) 라이브러리를 사용하여 데이터셋을 만들 수 있습니다.
|
||||
데이터셋 구조는 모델을 학습하려는 작업에 따라 달라집니다.
|
||||
가장 기본적인 데이터셋 구조는 unconditional 이미지 생성과 같은 작업을 위한 이미지 디렉토리입니다.
|
||||
또 다른 데이터셋 구조는 이미지 디렉토리와 text-to-image 생성과 같은 작업에 해당하는 텍스트 캡션이 포함된 텍스트 파일일 수 있습니다.
|
||||
|
||||
@@ -36,7 +36,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
[cloneofsimo](https://github.com/cloneofsimo)는 인기 있는 [lora](https://github.com/cloneofsimo/lora) GitHub 리포지토리에서 Stable Diffusion을 위한 LoRA 학습을 최초로 시도했습니다. 🧨 Diffusers는 [text-to-image 생성](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) 및 [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)을 지원합니다. 이 가이드는 두 가지를 모두 수행하는 방법을 보여줍니다.
|
||||
|
||||
모델을 저장하거나 커뮤니티와 공유하려면 Hugging Face 계정에 로그인하세요(아직 계정이 없는 경우 [생성](hf.co/join)하세요):
|
||||
모델을 저장하거나 커뮤니티와 공유하려면 Hugging Face 계정에 로그인하세요(아직 계정이 없는 경우 [생성](https://huggingface.co/join)하세요):
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
|
||||
@@ -76,7 +76,7 @@ huggingface-cli login
|
||||
... output_dir = "ddpm-butterflies-128" # 로컬 및 HF Hub에 저장되는 모델명
|
||||
|
||||
... push_to_hub = True # 저장된 모델을 HF Hub에 업로드할지 여부
|
||||
... hub_private_repo = False
|
||||
... hub_private_repo = None
|
||||
... overwrite_output_dir = True # 노트북을 다시 실행할 때 이전 모델에 덮어씌울지
|
||||
... seed = 0
|
||||
|
||||
|
||||
@@ -2154,6 +2154,7 @@ def main(args):
|
||||
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
if train_dataset.custom_instance_prompts:
|
||||
elems_to_repeat = 1
|
||||
if freeze_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
|
||||
prompts, text_encoders, tokenizers
|
||||
@@ -2168,17 +2169,21 @@ def main(args):
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
add_special_tokens=add_special_tokens_t5,
|
||||
)
|
||||
else:
|
||||
elems_to_repeat = len(prompts)
|
||||
|
||||
if not freeze_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
text_input_ids_list=[
|
||||
tokens_one.repeat(elems_to_repeat, 1),
|
||||
tokens_two.repeat(elems_to_repeat, 1),
|
||||
],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
device=accelerator.device,
|
||||
prompt=prompts,
|
||||
)
|
||||
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
model_input = latents_cache[step].sample()
|
||||
@@ -2371,6 +2376,9 @@ def main(args):
|
||||
epoch=epoch,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
if freeze_text_encoder:
|
||||
del text_encoder_one, text_encoder_two
|
||||
free_memory()
|
||||
@@ -2448,6 +2456,8 @@ def main(args):
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
images = None
|
||||
del pipeline
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ from accelerate.logging import get_logger
|
||||
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from packaging import version
|
||||
from peft import LoraConfig
|
||||
from peft import LoraConfig, set_peft_model_state_dict
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
from PIL import Image
|
||||
from PIL.ImageOps import exif_transpose
|
||||
@@ -59,12 +59,13 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.loaders import StableDiffusionLoraLoaderMixin
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_all_state_dict_to_peft,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_kohya,
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
@@ -1319,6 +1320,37 @@ def main(args):
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
|
||||
|
||||
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
|
||||
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
|
||||
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
if args.train_text_encoder:
|
||||
# Do we need to call `scale_lora_layers()` here?
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
|
||||
|
||||
_set_state_dict_into_text_encoder(
|
||||
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_one_
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
# are in `weight_dtype`. More details:
|
||||
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet_]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one_])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
lora_state_dict, network_alphas = StableDiffusionLoraLoaderMixin.lora_state_dict(input_dir)
|
||||
StableDiffusionLoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
|
||||
|
||||
|
||||
+135
-27
@@ -11,22 +11,22 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|
||||
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|NA|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|
||||
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|
||||
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[](https://huggingface.co/spaces/exx8/differential-diffusion) [](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
|
||||
| HD-Painter | [HD-Painter](https://github.com/Picsart-AI-Research/HD-Painter) enables prompt-faithfull and high resolution (up to 2k) image inpainting upon any diffusion-based image inpainting method. | [HD-Painter](#hd-painter) | [](https://huggingface.co/spaces/PAIR/HD-Painter) | [Manukyan Hayk](https://github.com/haikmanukyan) and [Sargsyan Andranik](https://github.com/AndranikSargsyan) |
|
||||
| Marigold Monocular Depth Estimation | A universal monocular depth estimator, utilizing Stable Diffusion, delivering sharp predictions in the wild. (See the [project page](https://marigoldmonodepth.github.io) and [full codebase](https://github.com/prs-eth/marigold) for more details.) | [Marigold Depth Estimation](#marigold-depth-estimation) | [](https://huggingface.co/spaces/toshas/marigold) [](https://colab.research.google.com/drive/12G8reD13DdpMie5ZQlaFNo2WCGeNUH-u?usp=sharing) | [Bingxin Ke](https://github.com/markkua) and [Anton Obukhov](https://github.com/toshas) |
|
||||
| LLM-grounded Diffusion (LMD+) | LMD greatly improves the prompt following ability of text-to-image generation models by introducing an LLM as a front-end prompt parser and layout planner. [Project page.](https://llm-grounded-diffusion.github.io/) [See our full codebase (also with diffusers).](https://github.com/TonyLianLong/LLM-groundedDiffusion) | [LLM-grounded Diffusion (LMD+)](#llm-grounded-diffusion) | [Huggingface Demo](https://huggingface.co/spaces/longlian/llm-grounded-diffusion) [](https://colab.research.google.com/drive/1SXzMSeAB-LJYISb2yrUOdypLz4OYWUKj) | [Long (Tony) Lian](https://tonylian.com/) |
|
||||
| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) |
|
||||
| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see <https://github.com/huggingface/diffusers/issues/841>) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) |
|
||||
| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) |
|
||||
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech)
|
||||
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) |
|
||||
| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_interpolation.ipynb) | [Nate Raw](https://github.com/nateraw/) |
|
||||
| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_mega.ipynb) | [Patrick von Platen](https://github.com/patrickvonplaten/) |
|
||||
| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/long_prompt_weighting_stable_diffusion.ipynb) | [SkyTNT](https://github.com/SkyTNT) |
|
||||
| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/speech_to_image.ipynb) | [Mikail Duzenli](https://github.com/MikailINTech)
|
||||
| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/wildcard_stable_diffusion.ipynb) | [Shyam Sudhakaran](https://github.com/shyamsn97) |
|
||||
| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) |
|
||||
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/multilingual_stable_diffusion.ipynb) | [Juan Carlos Piñeros](https://github.com/juancopi81) |
|
||||
| GlueGen Stable Diffusion | Stable Diffusion Pipeline that supports prompts in different languages using GlueGen adapter. | [GlueGen Stable Diffusion](#gluegen-stable-diffusion-pipeline) | - | [Phạm Hồng Vinh](https://github.com/rootonchair) |
|
||||
| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) |
|
||||
| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#text-based-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) |
|
||||
@@ -41,8 +41,8 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) |
|
||||
| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) |
|
||||
| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | - | [Joqsan Azocar](https://github.com/Joqsan) |
|
||||
| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint ) | - | [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
|
||||
| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/edict_image_pipeline.ipynb) | [Joqsan Azocar](https://github.com/Joqsan) |
|
||||
| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.09865) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint )|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/stable_diffusion_repaint.ipynb)| [Markus Pobitzer](https://github.com/Markus-Pobitzer) |
|
||||
| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) |
|
||||
| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) |
|
||||
| CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) |
|
||||
@@ -61,13 +61,13 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
| Regional Prompting Pipeline | Assign multiple prompts for different regions | [Regional Prompting Pipeline](#regional-prompting-pipeline) | - | [hako-mikan](https://github.com/hako-mikan) |
|
||||
| LDM3D-sr (LDM3D upscaler) | Upscale low resolution RGB and depth inputs to high resolution | [StableDiffusionUpscaleLDM3D Pipeline](https://github.com/estelleafl/diffusers/tree/ldm3d_upscaler_community/examples/community#stablediffusionupscaleldm3d-pipeline) | - | [Estelle Aflalo](https://github.com/estelleafl) |
|
||||
| AnimateDiff ControlNet Pipeline | Combines AnimateDiff with precise motion control using ControlNets | [AnimateDiff ControlNet Pipeline](#animatediff-controlnet-pipeline) | [](https://colab.research.google.com/drive/1SKboYeGjEQmQPWoFC0aLYpBlYdHXkvAu?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) and [Edoardo Botta](https://github.com/EdoardoBotta) |
|
||||
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#demofusion) | - | [Ruoyi Du](https://github.com/RuoyiDu) |
|
||||
| Instaflow Pipeline | Implementation of [InstaFlow! One-Step Stable Diffusion with Rectified Flow](https://arxiv.org/abs/2309.06380) | [Instaflow Pipeline](#instaflow-pipeline) | - | [Ayush Mangal](https://github.com/ayushtues) |
|
||||
| DemoFusion Pipeline | Implementation of [DemoFusion: Democratising High-Resolution Image Generation With No $$$](https://arxiv.org/abs/2311.16973) | [DemoFusion Pipeline](#demofusion) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/demo_fusion.ipynb) | [Ruoyi Du](https://github.com/RuoyiDu) |
|
||||
| Instaflow Pipeline | Implementation of [InstaFlow! One-Step Stable Diffusion with Rectified Flow](https://arxiv.org/abs/2309.06380) | [Instaflow Pipeline](#instaflow-pipeline) | [Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/insta_flow.ipynb) | [Ayush Mangal](https://github.com/ayushtues) |
|
||||
| Null-Text Inversion Pipeline | Implement [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/abs/2211.09794) as a pipeline. | [Null-Text Inversion](https://github.com/google/prompt-to-prompt/) | - | [Junsheng Luan](https://github.com/Junsheng121) |
|
||||
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#rerender-a-video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
|
||||
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
|
||||
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
|
||||
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ip_adapter_face_id.ipynb)| [Fabio Rigano](https://github.com/fabiorigano) |
|
||||
| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
|
||||
| UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |
|
||||
| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
|
||||
@@ -251,24 +251,30 @@ Example usage:
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
model_name = "black-forest-labs/FLUX.1-dev"
|
||||
prompt = "a watercolor painting of a unicorn"
|
||||
negative_prompt = "pink"
|
||||
|
||||
# Load the diffusion pipeline
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
model_name,
|
||||
torch_dtype=torch.bfloat16,
|
||||
custom_pipeline="pipeline_flux_with_cfg"
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
prompt = "a watercolor painting of a unicorn"
|
||||
negative_prompt = "pink"
|
||||
|
||||
# Generate the image
|
||||
img = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
true_cfg=1.5,
|
||||
guidance_scale=3.5,
|
||||
num_images_per_prompt=1,
|
||||
generator=torch.manual_seed(0)
|
||||
).images[0]
|
||||
|
||||
# Save the generated image
|
||||
img.save("cfg_flux.png")
|
||||
print("Image generated and saved successfully.")
|
||||
```
|
||||
|
||||
### Differential Diffusion
|
||||
@@ -841,6 +847,8 @@ out = pipe(
|
||||
wildcard_files=["object.txt", "animal.txt"],
|
||||
num_prompt_samples=1
|
||||
)
|
||||
out.images[0].save("image.png")
|
||||
torch.cuda.empty_cache()
|
||||
```
|
||||
|
||||
### Composable Stable diffusion
|
||||
@@ -2617,16 +2625,17 @@ for obj in range(bs):
|
||||
|
||||
### Stable Diffusion XL Reference
|
||||
|
||||
This pipeline uses the Reference. Refer to the [stable_diffusion_reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference).
|
||||
This pipeline uses the Reference. Refer to the [Stable Diffusion Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-reference) section for more information.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from PIL import Image
|
||||
# from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.schedulers import UniPCMultistepScheduler
|
||||
|
||||
input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
|
||||
from .stable_diffusion_xl_reference import StableDiffusionXLReferencePipeline
|
||||
|
||||
input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg")
|
||||
|
||||
# pipe = DiffusionPipeline.from_pretrained(
|
||||
# "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
@@ -2644,7 +2653,7 @@ pipe = StableDiffusionXLReferencePipeline.from_pretrained(
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
result_img = pipe(ref_image=input_image,
|
||||
prompt="1girl",
|
||||
prompt="a dog",
|
||||
num_inference_steps=20,
|
||||
reference_attn=True,
|
||||
reference_adain=True).images[0]
|
||||
@@ -2652,14 +2661,14 @@ result_img = pipe(ref_image=input_image,
|
||||
|
||||
Reference Image
|
||||
|
||||

|
||||

|
||||
|
||||
Output Image
|
||||
|
||||
`prompt: 1 girl`
|
||||
`prompt: a dog`
|
||||
|
||||
`reference_attn=True, reference_adain=True, num_inference_steps=20`
|
||||

|
||||
`reference_attn=False, reference_adain=True, num_inference_steps=20`
|
||||

|
||||
|
||||
Reference Image
|
||||

|
||||
@@ -2681,6 +2690,88 @@ Output Image
|
||||
`reference_attn=True, reference_adain=True, num_inference_steps=20`
|
||||

|
||||
|
||||
### Stable Diffusion XL ControlNet Reference
|
||||
|
||||
This pipeline uses the Reference Control and with ControlNet. Refer to the [Stable Diffusion ControlNet Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-controlnet-reference) and [Stable Diffusion XL Reference](https://github.com/huggingface/diffusers/blob/main/examples/community/README.md#stable-diffusion-xl-reference) sections for more information.
|
||||
|
||||
```py
|
||||
from diffusers import ControlNetModel, AutoencoderKL
|
||||
from diffusers.schedulers import UniPCMultistepScheduler
|
||||
from diffusers.utils import load_image
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import cv2
|
||||
from PIL import Image
|
||||
|
||||
from .stable_diffusion_xl_controlnet_reference import StableDiffusionXLControlNetReferencePipeline
|
||||
|
||||
# download an image
|
||||
canny_image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg"
|
||||
)
|
||||
|
||||
ref_image = load_image(
|
||||
"https://hf.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/hf-logo.png"
|
||||
)
|
||||
|
||||
# initialize the models and pipeline
|
||||
controlnet_conditioning_scale = 0.5 # recommended for good generalization
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionXLControlNetReferencePipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
|
||||
).to("cuda:0")
|
||||
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
|
||||
# get canny image
|
||||
image = np.array(canny_image)
|
||||
image = cv2.Canny(image, 100, 200)
|
||||
image = image[:, :, None]
|
||||
image = np.concatenate([image, image, image], axis=2)
|
||||
canny_image = Image.fromarray(image)
|
||||
|
||||
# generate image
|
||||
image = pipe(
|
||||
prompt="a cat",
|
||||
num_inference_steps=20,
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
image=canny_image,
|
||||
ref_image=ref_image,
|
||||
reference_attn=False,
|
||||
reference_adain=True,
|
||||
style_fidelity=1.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42)
|
||||
).images[0]
|
||||
```
|
||||
|
||||
Canny ControlNet Image
|
||||
|
||||

|
||||
|
||||
Reference Image
|
||||
|
||||

|
||||
|
||||
Output Image
|
||||
|
||||
`prompt: a cat`
|
||||
|
||||
`reference_attn=True, reference_adain=True, num_inference_steps=20, style_fidelity=1.0`
|
||||
|
||||

|
||||
|
||||
`reference_attn=False, reference_adain=True, num_inference_steps=20, style_fidelity=1.0`
|
||||
|
||||

|
||||
|
||||
`reference_attn=True, reference_adain=False, num_inference_steps=20, style_fidelity=1.0`
|
||||
|
||||

|
||||
|
||||
### Stable diffusion fabric pipeline
|
||||
|
||||
FABRIC approach applicable to a wide range of popular diffusion models, which exploits
|
||||
@@ -3376,6 +3467,20 @@ best quality, 3persons in garden, a boy blue shirt BREAK
|
||||
best quality, 3persons in garden, an old man red suit
|
||||
```
|
||||
|
||||
### Use base prompt
|
||||
|
||||
You can use a base prompt to apply the prompt to all areas. You can set a base prompt by adding `ADDBASE` at the end. Base prompts can also be combined with common prompts, but the base prompt must be specified first.
|
||||
|
||||
```
|
||||
2d animation style ADDBASE
|
||||
masterpiece, high quality ADDCOMM
|
||||
(blue sky)++ BREAK
|
||||
green hair twintail BREAK
|
||||
book shelf BREAK
|
||||
messy desk BREAK
|
||||
orange++ dress and sofa
|
||||
```
|
||||
|
||||
### Negative prompt
|
||||
|
||||
Negative prompts are equally effective across all regions, but it is possible to set region-specific prompts for negative prompts as well. The number of BREAKs must be the same as the number of prompts. If the number of prompts does not match, the negative prompts will be used without being divided into regions.
|
||||
@@ -3406,6 +3511,7 @@ pipe(prompt=prompt, rp_args=rp_args)
|
||||
### Optional Parameters
|
||||
|
||||
- `save_mask`: In `Prompt` mode, choose whether to output the generated mask along with the image. The default is `False`.
|
||||
- `base_ratio`: Used with `ADDBASE`. Sets the ratio of the base prompt; if base ratio is set to 0.2, then resulting images will consist of `20%*BASE_PROMPT + 80%*REGION_PROMPT`
|
||||
|
||||
The Pipeline supports `compel` syntax. Input prompts using the `compel` structure will be automatically applied and processed.
|
||||
|
||||
@@ -3734,6 +3840,7 @@ The original repo can be found at [repo](https://github.com/PRIS-CV/DemoFusion).
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
@@ -3857,9 +3964,10 @@ You can also combine it with LORA out of the box, like <https://huggingface.co/a
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("XCLIU/instaflow_0_9B_from_sd_1_5", torch_dtype=torch.float16, custom_pipeline="instaflow_one_step")
|
||||
pipe.to("cuda") ### if GPU is not available, comment this line
|
||||
pipe.to(device) ### if GPU is not available, comment this line
|
||||
pipe.load_lora_weights("artificialguybr/logo-redmond-1-5v-logo-lora-for-liberteredmond-sd-1-5")
|
||||
prompt = "logo, A logo for a fitness app, dynamic running figure, energetic colors (red, orange) ),LogoRedAF ,"
|
||||
images = pipe(prompt=prompt,
|
||||
@@ -4692,4 +4800,4 @@ with torch.no_grad():
|
||||
```
|
||||
|
||||
In the folder examples/pixart there is also a script that can be used to train new models.
|
||||
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
|
||||
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
|
||||
|
||||
@@ -6,9 +6,9 @@ If a community script doesn't work as expected, please open an issue and ping th
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
| Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)|
|
||||
| asymmetric tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling ) | | [alexisrolland](https://github.com/alexisrolland)|
|
||||
| Prompt scheduling callback |Allows changing prompts during a generation | [Prompt Scheduling](#prompt-scheduling ) | | [hlky](https://github.com/hlky)|
|
||||
| Using IP-Adapter with Negative Noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/ip_adapter_negative_noise.ipynb) | [Álvaro Somoza](https://github.com/asomoza)|
|
||||
| Asymmetric Tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#Asymmetric-Tiling ) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/asymetric_tiling.ipynb) | [alexisrolland](https://github.com/alexisrolland)|
|
||||
| Prompt Scheduling Callback |Allows changing prompts during a generation | [Prompt Scheduling-Callback](#Prompt-Scheduling-Callback ) |[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/prompt_scheduling_callback.ipynb) | [hlky](https://github.com/hlky)|
|
||||
|
||||
|
||||
## Example usages
|
||||
@@ -312,4 +312,6 @@ image = pipeline(
|
||||
callback_on_step_end=callback,
|
||||
callback_on_step_end_tensor_inputs=["prompt_embeds"],
|
||||
).images[0]
|
||||
torch.cuda.empty_cache()
|
||||
image.save('image.png')
|
||||
```
|
||||
|
||||
@@ -868,7 +868,7 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
blocks = list(zip(self.resnets, self.attentions))
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
@@ -1029,7 +1029,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
@@ -1191,7 +1191,7 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
@@ -1364,7 +1364,7 @@ class MatryoshkaTransformer2DModel(LegacyModelMixin, LegacyConfigMixin):
|
||||
|
||||
# Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -3,13 +3,12 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as FF
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils import USE_PEFT_BACKEND
|
||||
|
||||
|
||||
try:
|
||||
@@ -17,6 +16,7 @@ try:
|
||||
except ImportError:
|
||||
Compel = None
|
||||
|
||||
KBASE = "ADDBASE"
|
||||
KCOMM = "ADDCOMM"
|
||||
KBRK = "BREAK"
|
||||
|
||||
@@ -34,6 +34,11 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
Optional
|
||||
rp_args["save_mask"]: True/False (save masks in prompt mode)
|
||||
rp_args["power"]: int (power for attention maps in prompt mode)
|
||||
rp_args["base_ratio"]:
|
||||
float (Sets the ratio of the base prompt)
|
||||
ex) 0.2 (20%*BASE_PROMPT + 80%*REGION_PROMPT)
|
||||
[Use base prompt](https://github.com/hako-mikan/sd-webui-regional-prompter?tab=readme-ov-file#use-base-prompt)
|
||||
|
||||
Pipeline for text-to-image generation using Stable Diffusion.
|
||||
|
||||
@@ -70,6 +75,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -80,6 +86,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler,
|
||||
safety_checker,
|
||||
feature_extractor,
|
||||
image_encoder,
|
||||
requires_safety_checker,
|
||||
)
|
||||
self.register_modules(
|
||||
@@ -90,6 +97,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -110,17 +118,40 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
rp_args: Dict[str, str] = None,
|
||||
):
|
||||
active = KBRK in prompt[0] if isinstance(prompt, list) else KBRK in prompt
|
||||
use_base = KBASE in prompt[0] if isinstance(prompt, list) else KBASE in prompt
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "" if isinstance(prompt, str) else [""] * len(prompt)
|
||||
|
||||
device = self._execution_device
|
||||
regions = 0
|
||||
|
||||
self.base_ratio = float(rp_args["base_ratio"]) if "base_ratio" in rp_args else 0.0
|
||||
self.power = int(rp_args["power"]) if "power" in rp_args else 1
|
||||
|
||||
prompts = prompt if isinstance(prompt, list) else [prompt]
|
||||
n_prompts = negative_prompt if isinstance(prompt, str) else [negative_prompt]
|
||||
n_prompts = negative_prompt if isinstance(prompt, list) else [negative_prompt]
|
||||
self.batch = batch = num_images_per_prompt * len(prompts)
|
||||
|
||||
if use_base:
|
||||
bases = prompts.copy()
|
||||
n_bases = n_prompts.copy()
|
||||
|
||||
for i, prompt in enumerate(prompts):
|
||||
parts = prompt.split(KBASE)
|
||||
if len(parts) == 2:
|
||||
bases[i], prompts[i] = parts
|
||||
elif len(parts) > 2:
|
||||
raise ValueError(f"Multiple instances of {KBASE} found in prompt: {prompt}")
|
||||
for i, prompt in enumerate(n_prompts):
|
||||
n_parts = prompt.split(KBASE)
|
||||
if len(n_parts) == 2:
|
||||
n_bases[i], n_prompts[i] = n_parts
|
||||
elif len(n_parts) > 2:
|
||||
raise ValueError(f"Multiple instances of {KBASE} found in negative prompt: {prompt}")
|
||||
|
||||
all_bases_cn, _ = promptsmaker(bases, num_images_per_prompt)
|
||||
all_n_bases_cn, _ = promptsmaker(n_bases, num_images_per_prompt)
|
||||
|
||||
all_prompts_cn, all_prompts_p = promptsmaker(prompts, num_images_per_prompt)
|
||||
all_n_prompts_cn, _ = promptsmaker(n_prompts, num_images_per_prompt)
|
||||
|
||||
@@ -137,8 +168,16 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
conds = getcompelembs(all_prompts_cn)
|
||||
unconds = getcompelembs(all_n_prompts_cn)
|
||||
embs = getcompelembs(prompts)
|
||||
n_embs = getcompelembs(n_prompts)
|
||||
base_embs = getcompelembs(all_bases_cn) if use_base else None
|
||||
base_n_embs = getcompelembs(all_n_bases_cn) if use_base else None
|
||||
# When using base, it seems more reasonable to use base prompts as prompt_embeddings rather than regional prompts
|
||||
embs = getcompelembs(prompts) if not use_base else base_embs
|
||||
n_embs = getcompelembs(n_prompts) if not use_base else base_n_embs
|
||||
|
||||
if use_base and self.base_ratio > 0:
|
||||
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
|
||||
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
|
||||
|
||||
prompt = negative_prompt = None
|
||||
else:
|
||||
conds = self.encode_prompt(prompts, device, 1, True)[0]
|
||||
@@ -147,6 +186,18 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
if equal
|
||||
else self.encode_prompt(all_n_prompts_cn, device, 1, True)[0]
|
||||
)
|
||||
|
||||
if use_base and self.base_ratio > 0:
|
||||
base_embs = self.encode_prompt(bases, device, 1, True)[0]
|
||||
base_n_embs = (
|
||||
self.encode_prompt(n_bases, device, 1, True)[0]
|
||||
if equal
|
||||
else self.encode_prompt(all_n_bases_cn, device, 1, True)[0]
|
||||
)
|
||||
|
||||
conds = self.base_ratio * base_embs + (1 - self.base_ratio) * conds
|
||||
unconds = self.base_ratio * base_n_embs + (1 - self.base_ratio) * unconds
|
||||
|
||||
embs = n_embs = None
|
||||
|
||||
if not active:
|
||||
@@ -225,8 +276,6 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
@@ -247,16 +296,15 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
args = () if USE_PEFT_BACKEND else (scale,)
|
||||
query = attn.to_q(hidden_states, *args)
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states, *args)
|
||||
value = attn.to_v(encoder_hidden_states, *args)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
@@ -283,7 +331,7 @@ class RegionalPromptingStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states, *args)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
@@ -410,9 +458,9 @@ def promptsmaker(prompts, batch):
|
||||
add = ""
|
||||
if KCOMM in prompt:
|
||||
add, prompt = prompt.split(KCOMM)
|
||||
add = add + " "
|
||||
prompts = prompt.split(KBRK)
|
||||
out_p.append([add + p for p in prompts])
|
||||
add = add.strip() + " "
|
||||
prompts = [p.strip() for p in prompt.split(KBRK)]
|
||||
out_p.append([add + p for i, p in enumerate(prompts)])
|
||||
out = [None] * batch * len(out_p[0]) * len(out_p)
|
||||
for p, prs in enumerate(out_p): # inputs prompts
|
||||
for r, pr in enumerate(prs): # prompts for regions
|
||||
@@ -449,7 +497,6 @@ def make_cells(ratios):
|
||||
add = []
|
||||
startend(add, inratios[1:])
|
||||
icells.append(add)
|
||||
|
||||
return ocells, icells, sum(len(cell) for cell in icells)
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,6 @@
|
||||
# Based on stable_diffusion_reference.py
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -7,28 +8,33 @@ import PIL.Image
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.image_processor import PipelineImageInput
|
||||
from diffusers.models.attention import BasicTransformerBlock
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
UpBlock2D,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
|
||||
from diffusers.utils import PIL_INTERPOLATION, logging
|
||||
from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D
|
||||
from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
|
||||
from diffusers.utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm # type: ignore
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import UniPCMultistepScheduler
|
||||
>>> from diffusers.schedulers import UniPCMultistepScheduler
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
|
||||
>>> input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_reference_input_cat.jpg")
|
||||
|
||||
>>> pipe = StableDiffusionXLReferencePipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
@@ -38,7 +44,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
>>> result_img = pipe(ref_image=input_image,
|
||||
prompt="1girl",
|
||||
prompt="a dog",
|
||||
num_inference_steps=20,
|
||||
reference_attn=True,
|
||||
reference_adain=True).images[0]
|
||||
@@ -56,8 +62,6 @@ def torch_dfs(model: torch.nn.Module):
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
"""
|
||||
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
|
||||
@@ -72,33 +76,102 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
return noise_cfg
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
def _default_height_width(self, height, width, image):
|
||||
# NOTE: It is possible that a list of images have different
|
||||
# dimensions for each image, so just checking the first image
|
||||
# is not _exactly_ correct, but it is simple.
|
||||
while isinstance(image, list):
|
||||
image = image[0]
|
||||
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
|
||||
refimage = refimage.to(device=device)
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
if refimage.dtype != self.vae.dtype:
|
||||
refimage = refimage.to(dtype=self.vae.dtype)
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
if isinstance(generator, list):
|
||||
ref_image_latents = [
|
||||
self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
ref_image_latents = torch.cat(ref_image_latents, dim=0)
|
||||
else:
|
||||
ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
|
||||
ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
|
||||
|
||||
if height is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
height = image.height
|
||||
elif isinstance(image, torch.Tensor):
|
||||
height = image.shape[2]
|
||||
# duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
|
||||
if ref_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % ref_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
height = (height // 8) * 8 # round down to nearest multiple of 8
|
||||
ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
|
||||
|
||||
if width is None:
|
||||
if isinstance(image, PIL.Image.Image):
|
||||
width = image.width
|
||||
elif isinstance(image, torch.Tensor):
|
||||
width = image.shape[3]
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
return ref_image_latents
|
||||
|
||||
width = (width // 8) * 8
|
||||
|
||||
return height, width
|
||||
|
||||
def prepare_image(
|
||||
def prepare_ref_image(
|
||||
self,
|
||||
image,
|
||||
width,
|
||||
@@ -151,41 +224,42 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
|
||||
return image
|
||||
|
||||
def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance):
|
||||
refimage = refimage.to(device=device)
|
||||
if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
|
||||
self.upcast_vae()
|
||||
refimage = refimage.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
if refimage.dtype != self.vae.dtype:
|
||||
refimage = refimage.to(dtype=self.vae.dtype)
|
||||
# encode the mask image into latents space so we can concatenate it to the latents
|
||||
if isinstance(generator, list):
|
||||
ref_image_latents = [
|
||||
self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
ref_image_latents = torch.cat(ref_image_latents, dim=0)
|
||||
else:
|
||||
ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator)
|
||||
ref_image_latents = self.vae.config.scaling_factor * ref_image_latents
|
||||
def check_ref_inputs(
|
||||
self,
|
||||
ref_image,
|
||||
reference_guidance_start,
|
||||
reference_guidance_end,
|
||||
style_fidelity,
|
||||
reference_attn,
|
||||
reference_adain,
|
||||
):
|
||||
ref_image_is_pil = isinstance(ref_image, PIL.Image.Image)
|
||||
ref_image_is_tensor = isinstance(ref_image, torch.Tensor)
|
||||
|
||||
# duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method
|
||||
if ref_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % ref_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1)
|
||||
if not ref_image_is_pil and not ref_image_is_tensor:
|
||||
raise TypeError(
|
||||
f"ref image must be passed and be one of PIL image or torch tensor, but is {type(ref_image)}"
|
||||
)
|
||||
|
||||
ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents
|
||||
if not reference_attn and not reference_adain:
|
||||
raise ValueError("`reference_attn` or `reference_adain` must be True.")
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
ref_image_latents = ref_image_latents.to(device=device, dtype=dtype)
|
||||
return ref_image_latents
|
||||
if style_fidelity < 0.0:
|
||||
raise ValueError(f"style fidelity: {style_fidelity} can't be smaller than 0.")
|
||||
if style_fidelity > 1.0:
|
||||
raise ValueError(f"style fidelity: {style_fidelity} can't be larger than 1.0.")
|
||||
|
||||
if reference_guidance_start >= reference_guidance_end:
|
||||
raise ValueError(
|
||||
f"reference guidance start: {reference_guidance_start} cannot be larger or equal to reference guidance end: {reference_guidance_end}."
|
||||
)
|
||||
if reference_guidance_start < 0.0:
|
||||
raise ValueError(f"reference guidance start: {reference_guidance_start} can't be smaller than 0.")
|
||||
if reference_guidance_end > 1.0:
|
||||
raise ValueError(f"reference guidance end: {reference_guidance_end} can't be larger than 1.0.")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
@@ -194,6 +268,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
denoising_end: Optional[float] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
@@ -206,28 +282,220 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
ip_adapter_image: Optional[PipelineImageInput] = None,
|
||||
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
original_size: Optional[Tuple[int, int]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
target_size: Optional[Tuple[int, int]] = None,
|
||||
negative_original_size: Optional[Tuple[int, int]] = None,
|
||||
negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
negative_target_size: Optional[Tuple[int, int]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
attention_auto_machine_weight: float = 1.0,
|
||||
gn_auto_machine_weight: float = 1.0,
|
||||
reference_guidance_start: float = 0.0,
|
||||
reference_guidance_end: float = 1.0,
|
||||
style_fidelity: float = 0.5,
|
||||
reference_attn: bool = True,
|
||||
reference_adain: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True."
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in both text-encoders
|
||||
ref_image (`torch.Tensor`, `PIL.Image.Image`):
|
||||
The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `Torch.Tensor`, it is passed to Reference Control as is. `PIL.Image.Image` can
|
||||
also be accepted as an image.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
and checkpoints that are not specifically fine-tuned on low resolutions.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
Anything below 512 pixels won't work well for
|
||||
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
|
||||
and checkpoints that are not specifically fine-tuned on low resolutions.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
denoising_end (`float`, *optional*):
|
||||
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
|
||||
completed before it is intentionally prematurely terminated. As a result, the returned sample will
|
||||
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
|
||||
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
|
||||
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
|
||||
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
||||
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
||||
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
|
||||
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
|
||||
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
|
||||
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
|
||||
explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
|
||||
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
|
||||
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
For most cases, `target_size` should be set to the desired height and width of the generated image. If
|
||||
not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
|
||||
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
|
||||
negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a specific image resolution. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
|
||||
To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
|
||||
micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
|
||||
To negatively condition the generation process based on a target image resolution. It should be as same
|
||||
as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
|
||||
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
|
||||
information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
attention_auto_machine_weight (`float`):
|
||||
Weight of using reference query for self attention's context.
|
||||
If attention_auto_machine_weight=1.0, use reference query for all self attention's context.
|
||||
gn_auto_machine_weight (`float`):
|
||||
Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins.
|
||||
reference_guidance_start (`float`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the reference ControlNet starts applying.
|
||||
reference_guidance_end (`float`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the reference ControlNet stops applying.
|
||||
style_fidelity (`float`):
|
||||
style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important,
|
||||
elif style_fidelity=0.0, prompt more important, else balanced.
|
||||
reference_attn (`bool`):
|
||||
Whether to use reference query for self attention's context.
|
||||
reference_adain (`bool`):
|
||||
Whether to use reference adain.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
callback = kwargs.pop("callback", None)
|
||||
callback_steps = kwargs.pop("callback_steps", None)
|
||||
|
||||
if callback is not None:
|
||||
deprecate(
|
||||
"callback",
|
||||
"1.0.0",
|
||||
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
if callback_steps is not None:
|
||||
deprecate(
|
||||
"callback_steps",
|
||||
"1.0.0",
|
||||
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
|
||||
)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 0. Default height and width to unet
|
||||
# height, width = self._default_height_width(height, width, ref_image)
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
@@ -244,8 +512,27 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self.check_ref_inputs(
|
||||
ref_image,
|
||||
reference_guidance_start,
|
||||
reference_guidance_end,
|
||||
style_fidelity,
|
||||
reference_attn,
|
||||
reference_adain,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._clip_skip = clip_skip
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._denoising_end = denoising_end
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -256,15 +543,11 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = (
|
||||
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
lora_scale = (
|
||||
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
@@ -275,17 +558,19 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
prompt_2=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
lora_scale=lora_scale,
|
||||
clip_skip=self.clip_skip,
|
||||
)
|
||||
|
||||
# 4. Preprocess reference image
|
||||
ref_image = self.prepare_image(
|
||||
ref_image = self.prepare_ref_image(
|
||||
image=ref_image,
|
||||
width=width,
|
||||
height=height,
|
||||
@@ -296,9 +581,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
)
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
@@ -312,6 +597,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Prepare reference latent variables
|
||||
ref_image_latents = self.prepare_ref_latents(
|
||||
ref_image,
|
||||
@@ -319,13 +605,21 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
do_classifier_free_guidance,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 9. Modify self attebtion and group norm
|
||||
# 8.1 Create tensor stating which reference controlnets to keep
|
||||
reference_keeps = []
|
||||
for i in range(len(timesteps)):
|
||||
reference_keep = 1.0 - float(
|
||||
i / len(timesteps) < reference_guidance_start or (i + 1) / len(timesteps) > reference_guidance_end
|
||||
)
|
||||
reference_keeps.append(reference_keep)
|
||||
|
||||
# 8.2 Modify self attention and group norm
|
||||
MODE = "write"
|
||||
uc_mask = (
|
||||
torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt)
|
||||
@@ -333,6 +627,8 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
.bool()
|
||||
)
|
||||
|
||||
do_classifier_free_guidance = self.do_classifier_free_guidance
|
||||
|
||||
def hacked_basic_transformer_inner_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -604,7 +900,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
return hidden_states
|
||||
|
||||
def hacked_UpBlock2D_forward(
|
||||
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
|
||||
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, *args, **kwargs
|
||||
):
|
||||
eps = 1e-6
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
@@ -684,7 +980,7 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
module.var_bank = []
|
||||
module.gn_weight *= 2
|
||||
|
||||
# 10. Prepare added time ids & embeddings
|
||||
# 9. Prepare added time ids & embeddings
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
if self.text_encoder_2 is None:
|
||||
text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
|
||||
@@ -698,62 +994,101 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
if negative_original_size is not None and negative_target_size is not None:
|
||||
negative_add_time_ids = self._get_add_time_ids(
|
||||
negative_original_size,
|
||||
negative_crops_coords_top_left,
|
||||
negative_target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
else:
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# 11. Denoising loop
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
image_embeds = self.prepare_ip_adapter_image_embeds(
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
batch_size * num_images_per_prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
# 10. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
# 10.1 Apply denoising_end
|
||||
if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1:
|
||||
if (
|
||||
self.denoising_end is not None
|
||||
and isinstance(self.denoising_end, float)
|
||||
and self.denoising_end > 0
|
||||
and self.denoising_end < 1
|
||||
):
|
||||
discrete_timestep_cutoff = int(
|
||||
round(
|
||||
self.scheduler.config.num_train_timesteps
|
||||
- (denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
- (self.denoising_end * self.scheduler.config.num_train_timesteps)
|
||||
)
|
||||
)
|
||||
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
|
||||
timesteps = timesteps[:num_inference_steps]
|
||||
|
||||
# 11. Optionally get Guidance Scale Embedding
|
||||
timestep_cond = None
|
||||
if self.unet.config.time_cond_proj_dim is not None:
|
||||
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
|
||||
timestep_cond = self.get_guidance_scale_embedding(
|
||||
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
||||
).to(device=device, dtype=latents.dtype)
|
||||
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
||||
added_cond_kwargs["image_embeds"] = image_embeds
|
||||
|
||||
# ref only part
|
||||
noise = randn_tensor(
|
||||
ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
|
||||
)
|
||||
ref_xt = self.scheduler.add_noise(
|
||||
ref_image_latents,
|
||||
noise,
|
||||
t.reshape(
|
||||
1,
|
||||
),
|
||||
)
|
||||
ref_xt = self.scheduler.scale_model_input(ref_xt, t)
|
||||
if reference_keeps[i] > 0:
|
||||
noise = randn_tensor(
|
||||
ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype
|
||||
)
|
||||
ref_xt = self.scheduler.add_noise(
|
||||
ref_image_latents,
|
||||
noise,
|
||||
t.reshape(
|
||||
1,
|
||||
),
|
||||
)
|
||||
ref_xt = self.scheduler.scale_model_input(ref_xt, t)
|
||||
|
||||
MODE = "write"
|
||||
|
||||
self.unet(
|
||||
ref_xt,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
MODE = "write"
|
||||
self.unet(
|
||||
ref_xt,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
# predict the noise residual
|
||||
MODE = "read"
|
||||
@@ -761,22 +1096,44 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=self.cross_attention_kwargs,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
|
||||
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
@@ -785,6 +1142,9 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
# make sure the VAE is in float32 mode, as it overflows in float16
|
||||
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
||||
@@ -792,25 +1152,43 @@ class StableDiffusionXLReferencePipeline(StableDiffusionXLPipeline):
|
||||
if needs_upcasting:
|
||||
self.upcast_vae()
|
||||
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
|
||||
elif latents.dtype != self.vae.dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
self.vae = self.vae.to(latents.dtype)
|
||||
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
# cast back to fp16 if needed
|
||||
if needs_upcasting:
|
||||
self.vae.to(dtype=torch.float16)
|
||||
else:
|
||||
image = latents
|
||||
return StableDiffusionXLPipelineOutput(images=image)
|
||||
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
if not output_type == "latent":
|
||||
# apply watermark if available
|
||||
if self.watermark is not None:
|
||||
image = self.watermark.apply_watermark(image)
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
@@ -571,9 +571,6 @@ def parse_args(input_args=None):
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
|
||||
|
||||
if args.dataset_name is not None and args.train_data_dir is not None:
|
||||
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
|
||||
|
||||
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
|
||||
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
|
||||
|
||||
@@ -615,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator):
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
data_dir=args.train_data_dir,
|
||||
)
|
||||
else:
|
||||
if args.train_data_dir is not None:
|
||||
|
||||
@@ -598,9 +598,6 @@ def parse_args(input_args=None):
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
|
||||
|
||||
if args.dataset_name is not None and args.train_data_dir is not None:
|
||||
raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
|
||||
|
||||
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
|
||||
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
|
||||
|
||||
@@ -642,6 +639,7 @@ def get_train_dataset(args, accelerator):
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
data_dir=args.train_data_dir,
|
||||
)
|
||||
else:
|
||||
if args.train_data_dir is not None:
|
||||
|
||||
@@ -118,7 +118,7 @@ accelerate launch train_dreambooth_flux.py \
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
> [!NOTE]
|
||||
|
||||
@@ -105,7 +105,7 @@ accelerate launch train_dreambooth_sd3.py \
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
> [!NOTE]
|
||||
|
||||
@@ -99,7 +99,7 @@ accelerate launch train_dreambooth_lora_sdxl.py \
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
@@ -1648,11 +1648,15 @@ def main(args):
|
||||
prompt=prompts,
|
||||
)
|
||||
else:
|
||||
elems_to_repeat = len(prompts)
|
||||
if args.train_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
text_input_ids_list=[
|
||||
tokens_one.repeat(elems_to_repeat, 1),
|
||||
tokens_two.repeat(elems_to_repeat, 1),
|
||||
],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
device=accelerator.device,
|
||||
prompt=args.instance_prompt,
|
||||
|
||||
@@ -1294,10 +1294,13 @@ def main(args):
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))): # or text_encoder_two
|
||||
# both text encoders are of the same class, so we check hidden size to distinguish between the two
|
||||
hidden_size = unwrap_model(model).config.hidden_size
|
||||
if hidden_size == 768:
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif hidden_size == 1280:
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ from diffusers.utils import (
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_kohya,
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_peft_version,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
@@ -1183,26 +1184,33 @@ def main(args):
|
||||
text_encoder_one.gradient_checkpointing_enable()
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
def get_lora_config(rank, use_dora, target_modules):
|
||||
base_config = {
|
||||
"r": rank,
|
||||
"lora_alpha": rank,
|
||||
"init_lora_weights": "gaussian",
|
||||
"target_modules": target_modules,
|
||||
}
|
||||
if use_dora:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
base_config["use_dora"] = True
|
||||
|
||||
return LoraConfig(**base_config)
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
unet_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
||||
)
|
||||
unet_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
|
||||
unet_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=unet_target_modules)
|
||||
unet.add_adapter(unet_lora_config)
|
||||
|
||||
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
|
||||
# So, instead, we monkey-patch the forward calls of its attention-blocks.
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
use_dora=args.use_dora,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
text_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
|
||||
text_lora_config = get_lora_config(rank=args.rank, use_dora=args.use_dora, target_modules=text_target_modules)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
# IP Adapter Training Example
|
||||
|
||||
[IP Adapter](https://arxiv.org/abs/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.
|
||||
|
||||
## Training locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
Certainly! Below is the documentation in pure Markdown format:
|
||||
|
||||
### Accelerate Launch Command Documentation
|
||||
|
||||
#### Description:
|
||||
The Accelerate launch command is used to train a model using multiple GPUs and mixed precision training. It launches the training script `tutorial_train_ip-adapter.py` with specified parameters and configurations.
|
||||
|
||||
#### Usage Example:
|
||||
|
||||
```
|
||||
accelerate launch --mixed_precision "fp16" \
|
||||
tutorial_train_ip-adapter.py \
|
||||
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
|
||||
--image_encoder_path="{image_encoder_path}" \
|
||||
--data_json_file="{data.json}" \
|
||||
--data_root_path="{image_path}" \
|
||||
--mixed_precision="fp16" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=8 \
|
||||
--dataloader_num_workers=4 \
|
||||
--learning_rate=1e-04 \
|
||||
--weight_decay=0.01 \
|
||||
--output_dir="{output_dir}" \
|
||||
--save_steps=10000
|
||||
```
|
||||
|
||||
### Multi-GPU Script:
|
||||
```
|
||||
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
|
||||
tutorial_train_ip-adapter.py \
|
||||
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
|
||||
--image_encoder_path="{image_encoder_path}" \
|
||||
--data_json_file="{data.json}" \
|
||||
--data_root_path="{image_path}" \
|
||||
--mixed_precision="fp16" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=8 \
|
||||
--dataloader_num_workers=4 \
|
||||
--learning_rate=1e-04 \
|
||||
--weight_decay=0.01 \
|
||||
--output_dir="{output_dir}" \
|
||||
--save_steps=10000
|
||||
```
|
||||
|
||||
#### Parameters:
|
||||
- `--num_processes`: Number of processes to launch for distributed training (in this example, 8 processes).
|
||||
- `--multi_gpu`: Flag indicating the usage of multiple GPUs for training.
|
||||
- `--mixed_precision "fp16"`: Enables mixed precision training with 16-bit floating-point precision.
|
||||
- `tutorial_train_ip-adapter.py`: Name of the training script to be executed.
|
||||
- `--pretrained_model_name_or_path`: Path or identifier for a pretrained model.
|
||||
- `--image_encoder_path`: Path to the CLIP image encoder.
|
||||
- `--data_json_file`: Path to the training data in JSON format.
|
||||
- `--data_root_path`: Root path where training images are located.
|
||||
- `--resolution`: Resolution of input images (512x512 in this example).
|
||||
- `--train_batch_size`: Batch size for training data (8 in this example).
|
||||
- `--dataloader_num_workers`: Number of subprocesses for data loading (4 in this example).
|
||||
- `--learning_rate`: Learning rate for training (1e-04 in this example).
|
||||
- `--weight_decay`: Weight decay for regularization (0.01 in this example).
|
||||
- `--output_dir`: Directory to save model checkpoints and predictions.
|
||||
- `--save_steps`: Frequency of saving checkpoints during training (10000 in this example).
|
||||
|
||||
### Inference
|
||||
|
||||
#### Description:
|
||||
The provided inference code is used to load a trained model checkpoint and extract the components related to image projection and IP (Image Processing) adapter. These components are then saved into a binary file for later use in inference.
|
||||
|
||||
#### Usage Example:
|
||||
```python
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
# Load the trained model checkpoint in safetensors format
|
||||
ckpt = "checkpoint-50000/pytorch_model.safetensors"
|
||||
sd = load_file(ckpt) # Using safetensors load function
|
||||
|
||||
# Extract image projection and IP adapter components
|
||||
image_proj_sd = {}
|
||||
ip_sd = {}
|
||||
|
||||
for k in sd:
|
||||
if k.startswith("unet"):
|
||||
pass # Skip unet-related keys
|
||||
elif k.startswith("image_proj_model"):
|
||||
image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
|
||||
elif k.startswith("adapter_modules"):
|
||||
ip_sd[k.replace("adapter_modules.", "")] = sd[k]
|
||||
|
||||
# Save the components into separate safetensors files
|
||||
save_file(image_proj_sd, "image_proj.safetensors")
|
||||
save_file(ip_sd, "ip_adapter.safetensors")
|
||||
```
|
||||
|
||||
### Sample Inference Script using the CLIP Model
|
||||
|
||||
```python
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import CLIPProcessor, CLIPModel # Using the Hugging Face CLIP model
|
||||
|
||||
# Load model components from safetensors
|
||||
image_proj_ckpt = "image_proj.safetensors"
|
||||
ip_adapter_ckpt = "ip_adapter.safetensors"
|
||||
|
||||
# Load the saved weights
|
||||
image_proj_sd = load_file(image_proj_ckpt)
|
||||
ip_adapter_sd = load_file(ip_adapter_ckpt)
|
||||
|
||||
# Define the model Parameters
|
||||
class ImageProjectionModel(torch.nn.Module):
|
||||
def __init__(self, input_dim=768, output_dim=512): # CLIP's default embedding size is 768
|
||||
super().__init__()
|
||||
self.model = torch.nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
class IPAdapterModel(torch.nn.Module):
|
||||
def __init__(self, input_dim=512, output_dim=10): # Example for 10 classes
|
||||
super().__init__()
|
||||
self.model = torch.nn.Linear(input_dim, output_dim)
|
||||
|
||||
def forward(self, x):
|
||||
return self.model(x)
|
||||
|
||||
# Initialize models
|
||||
image_proj_model = ImageProjectionModel()
|
||||
ip_adapter_model = IPAdapterModel()
|
||||
|
||||
# Load weights into models
|
||||
image_proj_model.load_state_dict(image_proj_sd)
|
||||
ip_adapter_model.load_state_dict(ip_adapter_sd)
|
||||
|
||||
# Set models to evaluation mode
|
||||
image_proj_model.eval()
|
||||
ip_adapter_model.eval()
|
||||
|
||||
#Inference pipeline
|
||||
def inference(image_tensor):
|
||||
"""
|
||||
Run inference using the loaded models.
|
||||
|
||||
Args:
|
||||
image_tensor: Preprocessed image tensor from CLIPProcessor
|
||||
|
||||
Returns:
|
||||
Final inference results
|
||||
"""
|
||||
with torch.no_grad():
|
||||
# Step 1: Project the image features
|
||||
image_proj = image_proj_model(image_tensor)
|
||||
|
||||
# Step 2: Pass the projected features through the IP Adapter
|
||||
result = ip_adapter_model(image_proj)
|
||||
|
||||
return result
|
||||
|
||||
# Using CLIP for image preprocessing
|
||||
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
||||
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
|
||||
#Image file path
|
||||
image_path = "path/to/image.jpg"
|
||||
|
||||
# Preprocess the image
|
||||
inputs = processor(images=image_path, return_tensors="pt")
|
||||
image_features = clip_model.get_image_features(inputs["pixel_values"])
|
||||
|
||||
# Normalize the image features as per CLIP's recommendations
|
||||
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
||||
|
||||
# Run inference
|
||||
output = inference(image_features)
|
||||
print("Inference output:", output)
|
||||
```
|
||||
|
||||
#### Parameters:
|
||||
- `ckpt`: Path to the trained model checkpoint file.
|
||||
- `map_location="cpu"`: Specifies that the model should be loaded onto the CPU.
|
||||
- `image_proj_sd`: Dictionary to store the components related to image projection.
|
||||
- `ip_sd`: Dictionary to store the components related to the IP adapter.
|
||||
- `"unet"`, `"image_proj_model"`, `"adapter_modules"`: Prefixes indicating components of the model.
|
||||
@@ -0,0 +1,4 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ip_adapter
|
||||
@@ -0,0 +1,415 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import ProjectConfiguration
|
||||
from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
|
||||
from ip_adapter.ip_adapter_faceid import MLPProjModel
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
|
||||
|
||||
|
||||
# Dataset
|
||||
class MyDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.size = size
|
||||
self.i_drop_rate = i_drop_rate
|
||||
self.t_drop_rate = t_drop_rate
|
||||
self.ti_drop_rate = ti_drop_rate
|
||||
self.image_root_path = image_root_path
|
||||
|
||||
self.data = json.load(
|
||||
open(json_file)
|
||||
) # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}]
|
||||
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(self.size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
text = item["text"]
|
||||
image_file = item["image_file"]
|
||||
|
||||
# read image
|
||||
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
|
||||
image = self.transform(raw_image.convert("RGB"))
|
||||
|
||||
face_id_embed = torch.load(item["id_embed_file"], map_location="cpu")
|
||||
face_id_embed = torch.from_numpy(face_id_embed)
|
||||
|
||||
# drop
|
||||
drop_image_embed = 0
|
||||
rand_num = random.random()
|
||||
if rand_num < self.i_drop_rate:
|
||||
drop_image_embed = 1
|
||||
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
|
||||
text = ""
|
||||
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
|
||||
text = ""
|
||||
drop_image_embed = 1
|
||||
if drop_image_embed:
|
||||
face_id_embed = torch.zeros_like(face_id_embed)
|
||||
# get text and tokenize
|
||||
text_input_ids = self.tokenizer(
|
||||
text,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"text_input_ids": text_input_ids,
|
||||
"face_id_embed": face_id_embed,
|
||||
"drop_image_embed": drop_image_embed,
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
images = torch.stack([example["image"] for example in data])
|
||||
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
|
||||
face_id_embed = torch.stack([example["face_id_embed"] for example in data])
|
||||
drop_image_embeds = [example["drop_image_embed"] for example in data]
|
||||
|
||||
return {
|
||||
"images": images,
|
||||
"text_input_ids": text_input_ids,
|
||||
"face_id_embed": face_id_embed,
|
||||
"drop_image_embeds": drop_image_embeds,
|
||||
}
|
||||
|
||||
|
||||
class IPAdapter(torch.nn.Module):
|
||||
"""IP-Adapter"""
|
||||
|
||||
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
self.image_proj_model = image_proj_model
|
||||
self.adapter_modules = adapter_modules
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.load_from_checkpoint(ckpt_path)
|
||||
|
||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
||||
ip_tokens = self.image_proj_model(image_embeds)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
return noise_pred
|
||||
|
||||
def load_from_checkpoint(self, ckpt_path: str):
|
||||
# Calculate original checksums
|
||||
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
# Load state dict for image_proj_model and adapter_modules
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||
|
||||
# Calculate new checksums
|
||||
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||
|
||||
# Verify if the weights have changed
|
||||
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
||||
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
||||
|
||||
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_ip_adapter_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_json_file",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Training data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_root_path",
|
||||
type=str,
|
||||
default="",
|
||||
required=True,
|
||||
help="Training data root path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to CLIP image encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd-ip_adapter",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=("The resolution for input images"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Learning rate to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help=("Save a checkpoint of the training state every X updates"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load scheduler, tokenizer and models.
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
# image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
|
||||
# freeze parameters of models to save more memory
|
||||
unet.requires_grad_(False)
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
# image_encoder.requires_grad_(False)
|
||||
|
||||
# ip-adapter
|
||||
image_proj_model = MLPProjModel(
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
id_embeddings_dim=512,
|
||||
num_tokens=4,
|
||||
)
|
||||
# init adapter modules
|
||||
lora_rank = 128
|
||||
attn_procs = {}
|
||||
unet_sd = unet.state_dict()
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
if cross_attention_dim is None:
|
||||
attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
|
||||
)
|
||||
else:
|
||||
layer_name = name.split(".processor")[0]
|
||||
weights = {
|
||||
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
||||
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
||||
}
|
||||
attn_procs[name] = LoRAIPAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
|
||||
)
|
||||
attn_procs[name].load_state_dict(weights, strict=False)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||
|
||||
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
# unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
# image_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# optimizer
|
||||
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
|
||||
# dataloader
|
||||
train_dataset = MyDataset(
|
||||
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(0, args.num_train_epochs):
|
||||
begin = time.perf_counter()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
load_data_time = time.perf_counter() - begin
|
||||
with accelerator.accumulate(ip_adapter):
|
||||
# Convert images to latent space
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(
|
||||
batch["images"].to(accelerator.device, dtype=weight_dtype)
|
||||
).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
|
||||
|
||||
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
|
||||
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
|
||||
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(
|
||||
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
|
||||
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
|
||||
)
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.save_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
begin = time.perf_counter()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,422 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import ProjectConfiguration
|
||||
from ip_adapter.ip_adapter import ImageProjModel
|
||||
from ip_adapter.utils import is_torch2_available
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
|
||||
|
||||
|
||||
if is_torch2_available():
|
||||
from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
|
||||
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
|
||||
else:
|
||||
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
|
||||
|
||||
|
||||
# Dataset
|
||||
class MyDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.size = size
|
||||
self.i_drop_rate = i_drop_rate
|
||||
self.t_drop_rate = t_drop_rate
|
||||
self.ti_drop_rate = ti_drop_rate
|
||||
self.image_root_path = image_root_path
|
||||
|
||||
self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
|
||||
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(self.size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
text = item["text"]
|
||||
image_file = item["image_file"]
|
||||
|
||||
# read image
|
||||
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
|
||||
image = self.transform(raw_image.convert("RGB"))
|
||||
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
|
||||
|
||||
# drop
|
||||
drop_image_embed = 0
|
||||
rand_num = random.random()
|
||||
if rand_num < self.i_drop_rate:
|
||||
drop_image_embed = 1
|
||||
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
|
||||
text = ""
|
||||
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
|
||||
text = ""
|
||||
drop_image_embed = 1
|
||||
# get text and tokenize
|
||||
text_input_ids = self.tokenizer(
|
||||
text,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"text_input_ids": text_input_ids,
|
||||
"clip_image": clip_image,
|
||||
"drop_image_embed": drop_image_embed,
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
images = torch.stack([example["image"] for example in data])
|
||||
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
|
||||
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
|
||||
drop_image_embeds = [example["drop_image_embed"] for example in data]
|
||||
|
||||
return {
|
||||
"images": images,
|
||||
"text_input_ids": text_input_ids,
|
||||
"clip_images": clip_images,
|
||||
"drop_image_embeds": drop_image_embeds,
|
||||
}
|
||||
|
||||
|
||||
class IPAdapter(torch.nn.Module):
|
||||
"""IP-Adapter"""
|
||||
|
||||
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
self.image_proj_model = image_proj_model
|
||||
self.adapter_modules = adapter_modules
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.load_from_checkpoint(ckpt_path)
|
||||
|
||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
||||
ip_tokens = self.image_proj_model(image_embeds)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
return noise_pred
|
||||
|
||||
def load_from_checkpoint(self, ckpt_path: str):
|
||||
# Calculate original checksums
|
||||
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
# Load state dict for image_proj_model and adapter_modules
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||
|
||||
# Calculate new checksums
|
||||
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||
|
||||
# Verify if the weights have changed
|
||||
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
||||
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
||||
|
||||
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_ip_adapter_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_json_file",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Training data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_root_path",
|
||||
type=str,
|
||||
default="",
|
||||
required=True,
|
||||
help="Training data root path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to CLIP image encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd-ip_adapter",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=("The resolution for input images"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Learning rate to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help=("Save a checkpoint of the training state every X updates"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load scheduler, tokenizer and models.
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
|
||||
# freeze parameters of models to save more memory
|
||||
unet.requires_grad_(False)
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
|
||||
# ip-adapter
|
||||
image_proj_model = ImageProjModel(
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
clip_embeddings_dim=image_encoder.config.projection_dim,
|
||||
clip_extra_context_tokens=4,
|
||||
)
|
||||
# init adapter modules
|
||||
attn_procs = {}
|
||||
unet_sd = unet.state_dict()
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
if cross_attention_dim is None:
|
||||
attn_procs[name] = AttnProcessor()
|
||||
else:
|
||||
layer_name = name.split(".processor")[0]
|
||||
weights = {
|
||||
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
||||
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
||||
}
|
||||
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
|
||||
attn_procs[name].load_state_dict(weights)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||
|
||||
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
# unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# optimizer
|
||||
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
|
||||
# dataloader
|
||||
train_dataset = MyDataset(
|
||||
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(0, args.num_train_epochs):
|
||||
begin = time.perf_counter()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
load_data_time = time.perf_counter() - begin
|
||||
with accelerator.accumulate(ip_adapter):
|
||||
# Convert images to latent space
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(
|
||||
batch["images"].to(accelerator.device, dtype=weight_dtype)
|
||||
).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
with torch.no_grad():
|
||||
image_embeds = image_encoder(
|
||||
batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
|
||||
).image_embeds
|
||||
image_embeds_ = []
|
||||
for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
|
||||
if drop_image_embed == 1:
|
||||
image_embeds_.append(torch.zeros_like(image_embed))
|
||||
else:
|
||||
image_embeds_.append(image_embed)
|
||||
image_embeds = torch.stack(image_embeds_)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
|
||||
|
||||
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
|
||||
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
|
||||
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(
|
||||
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
|
||||
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
|
||||
)
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.save_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
begin = time.perf_counter()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,445 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import ProjectConfiguration
|
||||
from ip_adapter.resampler import Resampler
|
||||
from ip_adapter.utils import is_torch2_available
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
||||
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
|
||||
|
||||
|
||||
if is_torch2_available():
|
||||
from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
|
||||
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
|
||||
else:
|
||||
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
|
||||
|
||||
|
||||
# Dataset
|
||||
class MyDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.size = size
|
||||
self.i_drop_rate = i_drop_rate
|
||||
self.t_drop_rate = t_drop_rate
|
||||
self.ti_drop_rate = ti_drop_rate
|
||||
self.image_root_path = image_root_path
|
||||
|
||||
self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
|
||||
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop(self.size),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
text = item["text"]
|
||||
image_file = item["image_file"]
|
||||
|
||||
# read image
|
||||
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
|
||||
image = self.transform(raw_image.convert("RGB"))
|
||||
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
|
||||
|
||||
# drop
|
||||
drop_image_embed = 0
|
||||
rand_num = random.random()
|
||||
if rand_num < self.i_drop_rate:
|
||||
drop_image_embed = 1
|
||||
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
|
||||
text = ""
|
||||
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
|
||||
text = ""
|
||||
drop_image_embed = 1
|
||||
# get text and tokenize
|
||||
text_input_ids = self.tokenizer(
|
||||
text,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"text_input_ids": text_input_ids,
|
||||
"clip_image": clip_image,
|
||||
"drop_image_embed": drop_image_embed,
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
images = torch.stack([example["image"] for example in data])
|
||||
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
|
||||
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
|
||||
drop_image_embeds = [example["drop_image_embed"] for example in data]
|
||||
|
||||
return {
|
||||
"images": images,
|
||||
"text_input_ids": text_input_ids,
|
||||
"clip_images": clip_images,
|
||||
"drop_image_embeds": drop_image_embeds,
|
||||
}
|
||||
|
||||
|
||||
class IPAdapter(torch.nn.Module):
|
||||
"""IP-Adapter"""
|
||||
|
||||
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
self.image_proj_model = image_proj_model
|
||||
self.adapter_modules = adapter_modules
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.load_from_checkpoint(ckpt_path)
|
||||
|
||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
|
||||
ip_tokens = self.image_proj_model(image_embeds)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
return noise_pred
|
||||
|
||||
def load_from_checkpoint(self, ckpt_path: str):
|
||||
# Calculate original checksums
|
||||
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
# Check if 'latents' exists in both the saved state_dict and the current model's state_dict
|
||||
strict_load_image_proj_model = True
|
||||
if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
|
||||
# Check if the shapes are mismatched
|
||||
if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
|
||||
print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
|
||||
print("Removing 'latents' from checkpoint and loading the rest of the weights.")
|
||||
del state_dict["image_proj"]["latents"]
|
||||
strict_load_image_proj_model = False
|
||||
|
||||
# Load state dict for image_proj_model and adapter_modules
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||
|
||||
# Calculate new checksums
|
||||
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||
|
||||
# Verify if the weights have changed
|
||||
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
||||
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
||||
|
||||
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_ip_adapter_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_tokens",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of tokens to query from the CLIP image encoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_json_file",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Training data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_root_path",
|
||||
type=str,
|
||||
default="",
|
||||
required=True,
|
||||
help="Training data root path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to CLIP image encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd-ip_adapter",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=("The resolution for input images"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Learning rate to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help=("Save a checkpoint of the training state every X updates"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load scheduler, tokenizer and models.
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
|
||||
# freeze parameters of models to save more memory
|
||||
unet.requires_grad_(False)
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
|
||||
# ip-adapter-plus
|
||||
image_proj_model = Resampler(
|
||||
dim=unet.config.cross_attention_dim,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=args.num_tokens,
|
||||
embedding_dim=image_encoder.config.hidden_size,
|
||||
output_dim=unet.config.cross_attention_dim,
|
||||
ff_mult=4,
|
||||
)
|
||||
# init adapter modules
|
||||
attn_procs = {}
|
||||
unet_sd = unet.state_dict()
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
if cross_attention_dim is None:
|
||||
attn_procs[name] = AttnProcessor()
|
||||
else:
|
||||
layer_name = name.split(".processor")[0]
|
||||
weights = {
|
||||
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
||||
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
||||
}
|
||||
attn_procs[name] = IPAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens
|
||||
)
|
||||
attn_procs[name].load_state_dict(weights)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||
|
||||
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
# unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# optimizer
|
||||
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
|
||||
# dataloader
|
||||
train_dataset = MyDataset(
|
||||
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(0, args.num_train_epochs):
|
||||
begin = time.perf_counter()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
load_data_time = time.perf_counter() - begin
|
||||
with accelerator.accumulate(ip_adapter):
|
||||
# Convert images to latent space
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(
|
||||
batch["images"].to(accelerator.device, dtype=weight_dtype)
|
||||
).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
clip_images = []
|
||||
for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
|
||||
if drop_image_embed == 1:
|
||||
clip_images.append(torch.zeros_like(clip_image))
|
||||
else:
|
||||
clip_images.append(clip_image)
|
||||
clip_images = torch.stack(clip_images, dim=0)
|
||||
with torch.no_grad():
|
||||
image_embeds = image_encoder(
|
||||
clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
|
||||
|
||||
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
|
||||
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
|
||||
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(
|
||||
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
|
||||
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
|
||||
)
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.save_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
begin = time.perf_counter()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,520 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import ProjectConfiguration
|
||||
from ip_adapter.ip_adapter import ImageProjModel
|
||||
from ip_adapter.utils import is_torch2_available
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
|
||||
|
||||
|
||||
if is_torch2_available():
|
||||
from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
|
||||
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
|
||||
else:
|
||||
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
|
||||
|
||||
|
||||
# Dataset
|
||||
class MyDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
json_file,
|
||||
tokenizer,
|
||||
tokenizer_2,
|
||||
size=1024,
|
||||
center_crop=True,
|
||||
t_drop_rate=0.05,
|
||||
i_drop_rate=0.05,
|
||||
ti_drop_rate=0.05,
|
||||
image_root_path="",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.tokenizer_2 = tokenizer_2
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.i_drop_rate = i_drop_rate
|
||||
self.t_drop_rate = t_drop_rate
|
||||
self.ti_drop_rate = ti_drop_rate
|
||||
self.image_root_path = image_root_path
|
||||
|
||||
self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
|
||||
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
text = item["text"]
|
||||
image_file = item["image_file"]
|
||||
|
||||
# read image
|
||||
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
|
||||
|
||||
# original size
|
||||
original_width, original_height = raw_image.size
|
||||
original_size = torch.tensor([original_height, original_width])
|
||||
|
||||
image_tensor = self.transform(raw_image.convert("RGB"))
|
||||
# random crop
|
||||
delta_h = image_tensor.shape[1] - self.size
|
||||
delta_w = image_tensor.shape[2] - self.size
|
||||
assert not all([delta_h, delta_w])
|
||||
|
||||
if self.center_crop:
|
||||
top = delta_h // 2
|
||||
left = delta_w // 2
|
||||
else:
|
||||
top = np.random.randint(0, delta_h + 1)
|
||||
left = np.random.randint(0, delta_w + 1)
|
||||
image = transforms.functional.crop(image_tensor, top=top, left=left, height=self.size, width=self.size)
|
||||
crop_coords_top_left = torch.tensor([top, left])
|
||||
|
||||
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
|
||||
|
||||
# drop
|
||||
drop_image_embed = 0
|
||||
rand_num = random.random()
|
||||
if rand_num < self.i_drop_rate:
|
||||
drop_image_embed = 1
|
||||
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
|
||||
text = ""
|
||||
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
|
||||
text = ""
|
||||
drop_image_embed = 1
|
||||
|
||||
# get text and tokenize
|
||||
text_input_ids = self.tokenizer(
|
||||
text,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
text_input_ids_2 = self.tokenizer_2(
|
||||
text,
|
||||
max_length=self.tokenizer_2.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
return {
|
||||
"image": image,
|
||||
"text_input_ids": text_input_ids,
|
||||
"text_input_ids_2": text_input_ids_2,
|
||||
"clip_image": clip_image,
|
||||
"drop_image_embed": drop_image_embed,
|
||||
"original_size": original_size,
|
||||
"crop_coords_top_left": crop_coords_top_left,
|
||||
"target_size": torch.tensor([self.size, self.size]),
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def collate_fn(data):
|
||||
images = torch.stack([example["image"] for example in data])
|
||||
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
|
||||
text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0)
|
||||
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
|
||||
drop_image_embeds = [example["drop_image_embed"] for example in data]
|
||||
original_size = torch.stack([example["original_size"] for example in data])
|
||||
crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data])
|
||||
target_size = torch.stack([example["target_size"] for example in data])
|
||||
|
||||
return {
|
||||
"images": images,
|
||||
"text_input_ids": text_input_ids,
|
||||
"text_input_ids_2": text_input_ids_2,
|
||||
"clip_images": clip_images,
|
||||
"drop_image_embeds": drop_image_embeds,
|
||||
"original_size": original_size,
|
||||
"crop_coords_top_left": crop_coords_top_left,
|
||||
"target_size": target_size,
|
||||
}
|
||||
|
||||
|
||||
class IPAdapter(torch.nn.Module):
|
||||
"""IP-Adapter"""
|
||||
|
||||
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
self.image_proj_model = image_proj_model
|
||||
self.adapter_modules = adapter_modules
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.load_from_checkpoint(ckpt_path)
|
||||
|
||||
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
|
||||
ip_tokens = self.image_proj_model(image_embeds)
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
|
||||
# Predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs
|
||||
).sample
|
||||
return noise_pred
|
||||
|
||||
def load_from_checkpoint(self, ckpt_path: str):
|
||||
# Calculate original checksums
|
||||
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
# Load state dict for image_proj_model and adapter_modules
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
|
||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
|
||||
|
||||
# Calculate new checksums
|
||||
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
|
||||
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
|
||||
|
||||
# Verify if the weights have changed
|
||||
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
|
||||
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
|
||||
|
||||
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_ip_adapter_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_json_file",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Training data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_root_path",
|
||||
type=str,
|
||||
default="",
|
||||
required=True,
|
||||
help="Training data root path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to CLIP image encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd-ip_adapter",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=("The resolution for input images"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Learning rate to use.",
|
||||
)
|
||||
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=2000,
|
||||
help=("Save a checkpoint of the training state every X updates"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load scheduler, tokenizer and models.
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
|
||||
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder_2"
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
|
||||
# freeze parameters of models to save more memory
|
||||
unet.requires_grad_(False)
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
text_encoder_2.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
|
||||
# ip-adapter
|
||||
num_tokens = 4
|
||||
image_proj_model = ImageProjModel(
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
clip_embeddings_dim=image_encoder.config.projection_dim,
|
||||
clip_extra_context_tokens=num_tokens,
|
||||
)
|
||||
# init adapter modules
|
||||
attn_procs = {}
|
||||
unet_sd = unet.state_dict()
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
if cross_attention_dim is None:
|
||||
attn_procs[name] = AttnProcessor()
|
||||
else:
|
||||
layer_name = name.split(".processor")[0]
|
||||
weights = {
|
||||
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
||||
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
||||
}
|
||||
attn_procs[name] = IPAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens
|
||||
)
|
||||
attn_procs[name].load_state_dict(weights)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
|
||||
|
||||
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
# unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device) # use fp32
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder_2.to(accelerator.device, dtype=weight_dtype)
|
||||
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# optimizer
|
||||
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
|
||||
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
|
||||
|
||||
# dataloader
|
||||
train_dataset = MyDataset(
|
||||
args.data_json_file,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
size=args.resolution,
|
||||
image_root_path=args.data_root_path,
|
||||
)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
|
||||
|
||||
global_step = 0
|
||||
for epoch in range(0, args.num_train_epochs):
|
||||
begin = time.perf_counter()
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
load_data_time = time.perf_counter() - begin
|
||||
with accelerator.accumulate(ip_adapter):
|
||||
# Convert images to latent space
|
||||
with torch.no_grad():
|
||||
# vae of sdxl should use fp32
|
||||
latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae.dtype)).latent_dist.sample()
|
||||
latents = latents * vae.config.scaling_factor
|
||||
latents = latents.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(
|
||||
accelerator.device, dtype=weight_dtype
|
||||
)
|
||||
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
with torch.no_grad():
|
||||
image_embeds = image_encoder(
|
||||
batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
|
||||
).image_embeds
|
||||
image_embeds_ = []
|
||||
for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
|
||||
if drop_image_embed == 1:
|
||||
image_embeds_.append(torch.zeros_like(image_embed))
|
||||
else:
|
||||
image_embeds_.append(image_embed)
|
||||
image_embeds = torch.stack(image_embeds_)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_output = text_encoder(
|
||||
batch["text_input_ids"].to(accelerator.device), output_hidden_states=True
|
||||
)
|
||||
text_embeds = encoder_output.hidden_states[-2]
|
||||
encoder_output_2 = text_encoder_2(
|
||||
batch["text_input_ids_2"].to(accelerator.device), output_hidden_states=True
|
||||
)
|
||||
pooled_text_embeds = encoder_output_2[0]
|
||||
text_embeds_2 = encoder_output_2.hidden_states[-2]
|
||||
text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat
|
||||
|
||||
# add cond
|
||||
add_time_ids = [
|
||||
batch["original_size"].to(accelerator.device),
|
||||
batch["crop_coords_top_left"].to(accelerator.device),
|
||||
batch["target_size"].to(accelerator.device),
|
||||
]
|
||||
add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)
|
||||
unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds)
|
||||
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
|
||||
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if accelerator.is_main_process:
|
||||
print(
|
||||
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
|
||||
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
|
||||
)
|
||||
)
|
||||
|
||||
global_step += 1
|
||||
|
||||
if global_step % args.save_steps == 0:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
|
||||
begin = time.perf_counter()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -215,7 +215,7 @@ class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 2. Blocks
|
||||
for block_index, block in enumerate(self.transformer.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
# rc todo: for training and gradient checkpointing
|
||||
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
|
||||
exit(1)
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
|
||||
# Create a server
|
||||
|
||||
Diffusers' pipelines can be used as an inference engine for a server. It supports concurrent and multithreaded requests to generate images that may be requested by multiple users at the same time.
|
||||
|
||||
This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server, but feel free to use any pipeline you want.
|
||||
|
||||
|
||||
Start by navigating to the `examples/server` folder and installing all of the dependencies.
|
||||
|
||||
```py
|
||||
pip install .
|
||||
pip install -f requirements.txt
|
||||
```
|
||||
|
||||
Launch the server with the following command.
|
||||
|
||||
```py
|
||||
python server.py
|
||||
```
|
||||
|
||||
The server is accessed at http://localhost:8000. You can curl this model with the following command.
|
||||
```
|
||||
curl -X POST -H "Content-Type: application/json" --data '{"model": "something", "prompt": "a kitten in front of a fireplace"}' http://localhost:8000/v1/images/generations
|
||||
```
|
||||
|
||||
If you need to upgrade some dependencies, you can use either [pip-tools](https://github.com/jazzband/pip-tools) or [uv](https://github.com/astral-sh/uv). For example, upgrade the dependencies with `uv` using the following command.
|
||||
|
||||
```
|
||||
uv pip compile requirements.in -o requirements.txt
|
||||
```
|
||||
|
||||
|
||||
The server is built with [FastAPI](https://fastapi.tiangolo.com/async/). The endpoint for `v1/images/generations` is shown below.
|
||||
```py
|
||||
@app.post("/v1/images/generations")
|
||||
async def generate_image(image_input: TextToImageInput):
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
|
||||
pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
|
||||
generator = torch.Generator(device="cuda")
|
||||
generator.manual_seed(random.randint(0, 10000000))
|
||||
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
|
||||
logger.info(f"output: {output}")
|
||||
image_url = save_image(output.images[0])
|
||||
return {"data": [{"url": image_url}]}
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
elif hasattr(e, 'message'):
|
||||
raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
|
||||
```
|
||||
The `generate_image` function is defined as asynchronous with the [async](https://fastapi.tiangolo.com/async/) keyword so that FastAPI knows that whatever is happening in this function won't necessarily return a result right away. Once it hits some point in the function that it needs to await some other [Task](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task), the main thread goes back to answering other HTTP requests. This is shown in the code below with the [await](https://fastapi.tiangolo.com/async/#async-and-await) keyword.
|
||||
```py
|
||||
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator = generator))
|
||||
```
|
||||
At this point, the execution of the pipeline function is placed onto a [new thread](https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.run_in_executor), and the main thread performs other things until a result is returned from the `pipeline`.
|
||||
|
||||
Another important aspect of this implementation is creating a `pipeline` from `shared_pipeline`. The goal behind this is to avoid loading the underlying model more than once onto the GPU while still allowing for each new request that is running on a separate thread to have its own generator and scheduler. The scheduler, in particular, is not thread-safe, and it will cause errors like: `IndexError: index 21 is out of bounds for dimension 0 with size 21` if you try to use the same scheduler across multiple threads.
|
||||
@@ -0,0 +1,9 @@
|
||||
torch~=2.4.0
|
||||
transformers==4.46.1
|
||||
sentencepiece
|
||||
aiohttp
|
||||
py-consul
|
||||
prometheus_client >= 0.18.0
|
||||
prometheus-fastapi-instrumentator >= 7.0.0
|
||||
fastapi
|
||||
uvicorn
|
||||
@@ -0,0 +1,124 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile requirements.in -o requirements.txt
|
||||
aiohappyeyeballs==2.4.3
|
||||
# via aiohttp
|
||||
aiohttp==3.10.10
|
||||
# via -r requirements.in
|
||||
aiosignal==1.3.1
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.6.2.post1
|
||||
# via starlette
|
||||
attrs==24.2.0
|
||||
# via aiohttp
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
click==8.1.7
|
||||
# via uvicorn
|
||||
fastapi==0.115.3
|
||||
# via -r requirements.in
|
||||
filelock==3.16.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
frozenlist==1.5.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec==2024.10.0
|
||||
# via
|
||||
# huggingface-hub
|
||||
# torch
|
||||
h11==0.14.0
|
||||
# via uvicorn
|
||||
huggingface-hub==0.26.1
|
||||
# via
|
||||
# tokenizers
|
||||
# transformers
|
||||
idna==3.10
|
||||
# via
|
||||
# anyio
|
||||
# requests
|
||||
# yarl
|
||||
jinja2==3.1.4
|
||||
# via torch
|
||||
markupsafe==3.0.2
|
||||
# via jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
multidict==6.1.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
networkx==3.4.2
|
||||
# via torch
|
||||
numpy==2.1.2
|
||||
# via transformers
|
||||
packaging==24.1
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
prometheus-client==0.21.0
|
||||
# via
|
||||
# -r requirements.in
|
||||
# prometheus-fastapi-instrumentator
|
||||
prometheus-fastapi-instrumentator==7.0.0
|
||||
# via -r requirements.in
|
||||
propcache==0.2.0
|
||||
# via yarl
|
||||
py-consul==1.5.3
|
||||
# via -r requirements.in
|
||||
pydantic==2.9.2
|
||||
# via fastapi
|
||||
pydantic-core==2.23.4
|
||||
# via pydantic
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
regex==2024.9.11
|
||||
# via transformers
|
||||
requests==2.32.3
|
||||
# via
|
||||
# huggingface-hub
|
||||
# py-consul
|
||||
# transformers
|
||||
safetensors==0.4.5
|
||||
# via transformers
|
||||
sentencepiece==0.2.0
|
||||
# via -r requirements.in
|
||||
sniffio==1.3.1
|
||||
# via anyio
|
||||
starlette==0.41.0
|
||||
# via
|
||||
# fastapi
|
||||
# prometheus-fastapi-instrumentator
|
||||
sympy==1.13.3
|
||||
# via torch
|
||||
tokenizers==0.20.1
|
||||
# via transformers
|
||||
torch==2.4.1
|
||||
# via -r requirements.in
|
||||
tqdm==4.66.5
|
||||
# via
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
transformers==4.46.1
|
||||
# via -r requirements.in
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# fastapi
|
||||
# huggingface-hub
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# torch
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
uvicorn==0.32.0
|
||||
# via -r requirements.in
|
||||
yarl==1.16.0
|
||||
# via aiohttp
|
||||
@@ -0,0 +1,133 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from pydantic import BaseModel
|
||||
|
||||
from diffusers.pipelines.stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToImageInput(BaseModel):
|
||||
model: str
|
||||
prompt: str
|
||||
size: str | None = None
|
||||
n: int | None = None
|
||||
|
||||
|
||||
class HttpClient:
|
||||
session: aiohttp.ClientSession = None
|
||||
|
||||
def start(self):
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
async def stop(self):
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
|
||||
def __call__(self) -> aiohttp.ClientSession:
|
||||
assert self.session is not None
|
||||
return self.session
|
||||
|
||||
|
||||
class TextToImagePipeline:
|
||||
pipeline: StableDiffusion3Pipeline = None
|
||||
device: str = None
|
||||
|
||||
def start(self):
|
||||
if torch.cuda.is_available():
|
||||
model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-large")
|
||||
logger.info("Loading CUDA")
|
||||
self.device = "cuda"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(device=self.device)
|
||||
elif torch.backends.mps.is_available():
|
||||
model_path = os.getenv("MODEL_PATH", "stabilityai/stable-diffusion-3.5-medium")
|
||||
logger.info("Loading MPS for Mac M Series")
|
||||
self.device = "mps"
|
||||
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(device=self.device)
|
||||
else:
|
||||
raise Exception("No CUDA or MPS device available")
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
service_url = os.getenv("SERVICE_URL", "http://localhost:8000")
|
||||
image_dir = os.path.join(tempfile.gettempdir(), "images")
|
||||
if not os.path.exists(image_dir):
|
||||
os.makedirs(image_dir)
|
||||
app.mount("/images", StaticFiles(directory=image_dir), name="images")
|
||||
http_client = HttpClient()
|
||||
shared_pipeline = TextToImagePipeline()
|
||||
|
||||
# Configure CORS settings
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Allows all origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # Allows all methods, e.g., GET, POST, OPTIONS, etc.
|
||||
allow_headers=["*"], # Allows all headers
|
||||
)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
def startup():
|
||||
http_client.start()
|
||||
shared_pipeline.start()
|
||||
|
||||
|
||||
def save_image(image):
|
||||
filename = "draw" + str(uuid.uuid4()).split("-")[0] + ".png"
|
||||
image_path = os.path.join(image_dir, filename)
|
||||
# write image to disk at image_path
|
||||
logger.info(f"Saving image to {image_path}")
|
||||
image.save(image_path)
|
||||
return os.path.join(service_url, "images", filename)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
@app.post("/")
|
||||
@app.options("/")
|
||||
async def base():
|
||||
return "Welcome to Diffusers! Where you can use diffusion models to generate images"
|
||||
|
||||
|
||||
@app.post("/v1/images/generations")
|
||||
async def generate_image(image_input: TextToImageInput):
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
scheduler = shared_pipeline.pipeline.scheduler.from_config(shared_pipeline.pipeline.scheduler.config)
|
||||
pipeline = StableDiffusion3Pipeline.from_pipe(shared_pipeline.pipeline, scheduler=scheduler)
|
||||
generator = torch.Generator(device=shared_pipeline.device)
|
||||
generator.manual_seed(random.randint(0, 10000000))
|
||||
output = await loop.run_in_executor(None, lambda: pipeline(image_input.prompt, generator=generator))
|
||||
logger.info(f"output: {output}")
|
||||
image_url = save_image(output.images[0])
|
||||
return {"data": [{"url": image_url}]}
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
elif hasattr(e, "message"):
|
||||
raise HTTPException(status_code=500, detail=e.message + traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e) + traceback.format_exc())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
@@ -483,7 +483,6 @@ def parse_args(input_args=None):
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("Need either a dataset name or a training folder.")
|
||||
|
||||
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
|
||||
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
|
||||
|
||||
@@ -824,9 +823,7 @@ def main(args):
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
|
||||
@@ -80,6 +80,8 @@ TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"post_attn1_layernorm": "norm2.norm",
|
||||
"time_embed.0": "time_embedding.linear_1",
|
||||
"time_embed.2": "time_embedding.linear_2",
|
||||
"ofs_embed.0": "ofs_embedding.linear_1",
|
||||
"ofs_embed.2": "ofs_embedding.linear_2",
|
||||
"mixins.patch_embed": "patch_embed",
|
||||
"mixins.final_layer.norm_final": "norm_out.norm",
|
||||
"mixins.final_layer.linear": "proj_out",
|
||||
@@ -140,6 +142,7 @@ def convert_transformer(
|
||||
use_rotary_positional_embeddings: bool,
|
||||
i2v: bool,
|
||||
dtype: torch.dtype,
|
||||
init_kwargs: Dict[str, Any],
|
||||
):
|
||||
PREFIX_KEY = "model.diffusion_model."
|
||||
|
||||
@@ -149,7 +152,9 @@ def convert_transformer(
|
||||
num_layers=num_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_rotary_positional_embeddings=use_rotary_positional_embeddings,
|
||||
use_learned_positional_embeddings=i2v,
|
||||
ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None, # CogVideoX1.5-5B-I2V
|
||||
use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None, # CogVideoX-5B-I2V
|
||||
**init_kwargs,
|
||||
).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
@@ -163,13 +168,18 @@ def convert_transformer(
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
|
||||
def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
|
||||
init_kwargs = {"scaling_factor": scaling_factor}
|
||||
if version == "1.5":
|
||||
init_kwargs.update({"invert_scale_latents": True})
|
||||
|
||||
original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
|
||||
vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
|
||||
vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
@@ -187,6 +197,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
|
||||
return vae
|
||||
|
||||
|
||||
def get_transformer_init_kwargs(version: str):
|
||||
if version == "1.0":
|
||||
vae_scale_factor_spatial = 8
|
||||
init_kwargs = {
|
||||
"patch_size": 2,
|
||||
"patch_size_t": None,
|
||||
"patch_bias": True,
|
||||
"sample_height": 480 // vae_scale_factor_spatial,
|
||||
"sample_width": 720 // vae_scale_factor_spatial,
|
||||
"sample_frames": 49,
|
||||
}
|
||||
|
||||
elif version == "1.5":
|
||||
vae_scale_factor_spatial = 8
|
||||
init_kwargs = {
|
||||
"patch_size": 2,
|
||||
"patch_size_t": 2,
|
||||
"patch_bias": False,
|
||||
"sample_height": 300,
|
||||
"sample_width": 300,
|
||||
"sample_frames": 81,
|
||||
}
|
||||
else:
|
||||
raise ValueError("Unsupported version of CogVideoX.")
|
||||
|
||||
return init_kwargs
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
@@ -202,6 +240,12 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--typecast_text_encoder",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether or not to apply fp16/bf16 precision to text_encoder",
|
||||
)
|
||||
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
|
||||
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
|
||||
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
|
||||
@@ -214,7 +258,18 @@ def get_args():
|
||||
parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
|
||||
# For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
|
||||
parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
|
||||
parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
|
||||
parser.add_argument(
|
||||
"--i2v",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
choices=["1.0", "1.5"],
|
||||
default="1.0",
|
||||
help="Which version of CogVideoX to use for initializing default modeling parameters.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -230,6 +285,7 @@ if __name__ == "__main__":
|
||||
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
init_kwargs = get_transformer_init_kwargs(args.version)
|
||||
transformer = convert_transformer(
|
||||
args.transformer_ckpt_path,
|
||||
args.num_layers,
|
||||
@@ -237,14 +293,19 @@ if __name__ == "__main__":
|
||||
args.use_rotary_positional_embeddings,
|
||||
args.i2v,
|
||||
dtype,
|
||||
init_kwargs,
|
||||
)
|
||||
if args.vae_ckpt_path is not None:
|
||||
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
|
||||
# Keep VAE in float32 for better quality
|
||||
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
|
||||
|
||||
text_encoder_id = "google/t5-v1_1-xxl"
|
||||
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
|
||||
if args.typecast_text_encoder:
|
||||
text_encoder = text_encoder.to(dtype=dtype)
|
||||
|
||||
# Apparently, the conversion does not work anymore without this :shrug:
|
||||
for param in text_encoder.parameters():
|
||||
param.data = param.data.contiguous()
|
||||
@@ -276,11 +337,6 @@ if __name__ == "__main__":
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
if args.fp16:
|
||||
pipe = pipe.to(dtype=torch.float16)
|
||||
if args.bf16:
|
||||
pipe = pipe.to(dtype=torch.bfloat16)
|
||||
|
||||
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
|
||||
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
|
||||
# is either fp16/bf16 here).
|
||||
|
||||
@@ -36,7 +36,7 @@ from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 224
|
||||
|
||||
|
||||
@@ -31,12 +31,14 @@ python scripts/convert_flux_to_diffusers.py \
|
||||
--vae
|
||||
"""
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
|
||||
parser.add_argument("--filename", default="flux.safetensors", type=str)
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--in_channels", type=int, default=64)
|
||||
parser.add_argument("--out_channels", type=int, default=None)
|
||||
parser.add_argument("--vae", action="store_true")
|
||||
parser.add_argument("--transformer", action="store_true")
|
||||
parser.add_argument("--output_path", type=str)
|
||||
@@ -279,10 +281,13 @@ def main(args):
|
||||
num_single_layers = 38
|
||||
inner_dim = 3072
|
||||
mlp_ratio = 4.0
|
||||
|
||||
converted_transformer_state_dict = convert_flux_transformer_checkpoint_to_diffusers(
|
||||
original_ckpt, num_layers, num_single_layers, inner_dim, mlp_ratio=mlp_ratio
|
||||
)
|
||||
transformer = FluxTransformer2DModel(guidance_embeds=has_guidance)
|
||||
transformer = FluxTransformer2DModel(
|
||||
in_channels=args.in_channels, out_channels=args.out_channels, guidance_embeds=has_guidance
|
||||
)
|
||||
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
||||
|
||||
print(
|
||||
|
||||
@@ -10,7 +10,7 @@ from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, Mochi
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 256
|
||||
|
||||
|
||||
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
A script to convert Stable Diffusion 3.5 ControlNet checkpoints to the Diffusers format.
|
||||
|
||||
Example:
|
||||
Convert a SD3.5 ControlNet checkpoint to Diffusers format using local file:
|
||||
```bash
|
||||
python scripts/convert_sd3_controlnet_to_diffusers.py \
|
||||
--checkpoint_path "path/to/local/sd3.5_large_controlnet_canny.safetensors" \
|
||||
--output_path "output/sd35-controlnet-canny" \
|
||||
--dtype "fp16" # optional, defaults to fp32
|
||||
```
|
||||
|
||||
Or download and convert from HuggingFace repository:
|
||||
```bash
|
||||
python scripts/convert_sd3_controlnet_to_diffusers.py \
|
||||
--original_state_dict_repo_id "stabilityai/stable-diffusion-3.5-controlnets" \
|
||||
--filename "sd3.5_large_controlnet_canny.safetensors" \
|
||||
--output_path "/raid/yiyi/sd35-controlnet-canny-diffusers" \
|
||||
--dtype "fp32" # optional, defaults to fp32
|
||||
```
|
||||
|
||||
Note:
|
||||
The script supports the following ControlNet types from SD3.5:
|
||||
- Canny edge detection
|
||||
- Depth estimation
|
||||
- Blur detection
|
||||
|
||||
The checkpoint files can be downloaded from:
|
||||
https://huggingface.co/stabilityai/stable-diffusion-3.5-controlnets
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers import SD3ControlNetModel
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None, help="Path to local checkpoint file")
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id", type=str, default=None, help="HuggingFace repo ID containing the checkpoint"
|
||||
)
|
||||
parser.add_argument("--filename", type=str, default=None, help="Filename of the checkpoint in the HF repo")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
|
||||
parser.add_argument(
|
||||
"--dtype", type=str, default="fp32", help="Data type for the converted model (fp16, bf16, or fp32)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def load_original_checkpoint(args):
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
if args.filename is None:
|
||||
raise ValueError("When using `original_state_dict_repo_id`, `filename` must also be specified")
|
||||
print(f"Downloading checkpoint from {args.original_state_dict_repo_id}/{args.filename}")
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
print(f"Loading checkpoint from local path: {args.checkpoint_path}")
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def convert_sd3_controlnet_checkpoint_to_diffusers(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
|
||||
# Direct mappings for controlnet blocks
|
||||
for i in range(19): # 19 controlnet blocks
|
||||
converted_state_dict[f"controlnet_blocks.{i}.weight"] = original_state_dict[f"controlnet_blocks.{i}.weight"]
|
||||
converted_state_dict[f"controlnet_blocks.{i}.bias"] = original_state_dict[f"controlnet_blocks.{i}.bias"]
|
||||
|
||||
# Positional embeddings
|
||||
converted_state_dict["pos_embed_input.proj.weight"] = original_state_dict["pos_embed_input.proj.weight"]
|
||||
converted_state_dict["pos_embed_input.proj.bias"] = original_state_dict["pos_embed_input.proj.bias"]
|
||||
|
||||
# Time and text embeddings
|
||||
time_text_mappings = {
|
||||
"time_text_embed.timestep_embedder.linear_1.weight": "time_text_embed.timestep_embedder.linear_1.weight",
|
||||
"time_text_embed.timestep_embedder.linear_1.bias": "time_text_embed.timestep_embedder.linear_1.bias",
|
||||
"time_text_embed.timestep_embedder.linear_2.weight": "time_text_embed.timestep_embedder.linear_2.weight",
|
||||
"time_text_embed.timestep_embedder.linear_2.bias": "time_text_embed.timestep_embedder.linear_2.bias",
|
||||
"time_text_embed.text_embedder.linear_1.weight": "time_text_embed.text_embedder.linear_1.weight",
|
||||
"time_text_embed.text_embedder.linear_1.bias": "time_text_embed.text_embedder.linear_1.bias",
|
||||
"time_text_embed.text_embedder.linear_2.weight": "time_text_embed.text_embedder.linear_2.weight",
|
||||
"time_text_embed.text_embedder.linear_2.bias": "time_text_embed.text_embedder.linear_2.bias",
|
||||
}
|
||||
|
||||
for new_key, old_key in time_text_mappings.items():
|
||||
if old_key in original_state_dict:
|
||||
converted_state_dict[new_key] = original_state_dict[old_key]
|
||||
|
||||
# Transformer blocks
|
||||
for i in range(19):
|
||||
# Split QKV into separate Q, K, V
|
||||
qkv_weight = original_state_dict[f"transformer_blocks.{i}.attn.qkv.weight"]
|
||||
qkv_bias = original_state_dict[f"transformer_blocks.{i}.attn.qkv.bias"]
|
||||
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
|
||||
|
||||
block_mappings = {
|
||||
f"transformer_blocks.{i}.attn.to_q.weight": q,
|
||||
f"transformer_blocks.{i}.attn.to_q.bias": q_bias,
|
||||
f"transformer_blocks.{i}.attn.to_k.weight": k,
|
||||
f"transformer_blocks.{i}.attn.to_k.bias": k_bias,
|
||||
f"transformer_blocks.{i}.attn.to_v.weight": v,
|
||||
f"transformer_blocks.{i}.attn.to_v.bias": v_bias,
|
||||
# Output projections
|
||||
f"transformer_blocks.{i}.attn.to_out.0.weight": original_state_dict[
|
||||
f"transformer_blocks.{i}.attn.proj.weight"
|
||||
],
|
||||
f"transformer_blocks.{i}.attn.to_out.0.bias": original_state_dict[
|
||||
f"transformer_blocks.{i}.attn.proj.bias"
|
||||
],
|
||||
# Feed forward
|
||||
f"transformer_blocks.{i}.ff.net.0.proj.weight": original_state_dict[
|
||||
f"transformer_blocks.{i}.mlp.fc1.weight"
|
||||
],
|
||||
f"transformer_blocks.{i}.ff.net.0.proj.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc1.bias"],
|
||||
f"transformer_blocks.{i}.ff.net.2.weight": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.weight"],
|
||||
f"transformer_blocks.{i}.ff.net.2.bias": original_state_dict[f"transformer_blocks.{i}.mlp.fc2.bias"],
|
||||
# Norms
|
||||
f"transformer_blocks.{i}.norm1.linear.weight": original_state_dict[
|
||||
f"transformer_blocks.{i}.adaLN_modulation.1.weight"
|
||||
],
|
||||
f"transformer_blocks.{i}.norm1.linear.bias": original_state_dict[
|
||||
f"transformer_blocks.{i}.adaLN_modulation.1.bias"
|
||||
],
|
||||
}
|
||||
converted_state_dict.update(block_mappings)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def main(args):
|
||||
original_ckpt = load_original_checkpoint(args)
|
||||
original_dtype = next(iter(original_ckpt.values())).dtype
|
||||
|
||||
# Initialize dtype with fp32 as default
|
||||
if args.dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
elif args.dtype == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
elif args.dtype == "fp32":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {args.dtype}. Must be one of: fp16, bf16, fp32")
|
||||
|
||||
if dtype != original_dtype:
|
||||
print(
|
||||
f"Converting checkpoint from {original_dtype} to {dtype}. This can lead to unexpected results, proceed with caution."
|
||||
)
|
||||
|
||||
converted_controlnet_state_dict = convert_sd3_controlnet_checkpoint_to_diffusers(original_ckpt)
|
||||
|
||||
controlnet = SD3ControlNetModel(
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
num_layers=19,
|
||||
attention_head_dim=64,
|
||||
num_attention_heads=38,
|
||||
joint_attention_dim=None,
|
||||
caption_projection_dim=2048,
|
||||
pooled_projection_dim=2048,
|
||||
out_channels=16,
|
||||
pos_embed_max_size=None,
|
||||
pos_embed_type=None,
|
||||
use_pos_embed=False,
|
||||
force_zeros_for_pooled_projection=False,
|
||||
)
|
||||
|
||||
controlnet.load_state_dict(converted_controlnet_state_dict, strict=True)
|
||||
|
||||
print(f"Saving SD3 ControlNet in Diffusers format in {args.output_path}.")
|
||||
controlnet.to(dtype).save_pretrained(args.output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -11,7 +11,7 @@ from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", type=str)
|
||||
|
||||
@@ -130,7 +130,7 @@ _deps = [
|
||||
"regex!=2019.12.17",
|
||||
"requests",
|
||||
"tensorboard",
|
||||
"torch>=1.4,<2.5.0",
|
||||
"torch>=1.4",
|
||||
"torchvision",
|
||||
"transformers>=4.41.2",
|
||||
"urllib3<=2.0.0",
|
||||
|
||||
@@ -107,6 +107,7 @@ else:
|
||||
"ModelMixin",
|
||||
"MotionAdapter",
|
||||
"MultiAdapter",
|
||||
"MultiControlNetModel",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"SD3ControlNetModel",
|
||||
@@ -268,12 +269,16 @@ else:
|
||||
"CogVideoXVideoToVideoPipeline",
|
||||
"CogView3PlusPipeline",
|
||||
"CycleDiffusionPipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
"FluxControlNetImg2ImgPipeline",
|
||||
"FluxControlNetInpaintPipeline",
|
||||
"FluxControlNetPipeline",
|
||||
"FluxControlPipeline",
|
||||
"FluxFillPipeline",
|
||||
"FluxImg2ImgPipeline",
|
||||
"FluxInpaintPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
@@ -320,6 +325,7 @@ else:
|
||||
"PixArtAlphaPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"ReduxImageEncoder",
|
||||
"SemanticStableDiffusionPipeline",
|
||||
"ShapEImg2ImgPipeline",
|
||||
"ShapEPipeline",
|
||||
@@ -332,6 +338,7 @@ else:
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
"StableDiffusion3PAGImg2ImgPipeline",
|
||||
"StableDiffusion3PAGPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
@@ -592,6 +599,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ModelMixin,
|
||||
MotionAdapter,
|
||||
MultiAdapter,
|
||||
MultiControlNetModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SD3ControlNetModel,
|
||||
@@ -732,12 +740,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogVideoXVideoToVideoPipeline,
|
||||
CogView3PlusPipeline,
|
||||
CycleDiffusionPipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
FluxControlNetImg2ImgPipeline,
|
||||
FluxControlNetInpaintPipeline,
|
||||
FluxControlNetPipeline,
|
||||
FluxControlPipeline,
|
||||
FluxFillPipeline,
|
||||
FluxImg2ImgPipeline,
|
||||
FluxInpaintPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
@@ -784,6 +796,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtAlphaPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
ReduxImageEncoder,
|
||||
SemanticStableDiffusionPipeline,
|
||||
ShapEImg2ImgPipeline,
|
||||
ShapEPipeline,
|
||||
@@ -795,6 +808,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3PAGImg2ImgPipeline,
|
||||
StableDiffusion3PAGPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
|
||||
@@ -170,7 +170,7 @@ class ConfigMixin:
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
private = kwargs.pop("private", False)
|
||||
private = kwargs.pop("private", None)
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
@@ -347,6 +347,7 @@ class ConfigMixin:
|
||||
_ = kwargs.pop("mirror", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
dduf_reader = kwargs.pop("dduf_reader", None)
|
||||
|
||||
user_agent = {**user_agent, "file_type": "config"}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
@@ -358,8 +359,22 @@ class ConfigMixin:
|
||||
"`self.config_name` is not defined. Note that one should not load a config from "
|
||||
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
|
||||
)
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path):
|
||||
# Custom path for now
|
||||
if dduf_reader:
|
||||
if subfolder is not None:
|
||||
if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
|
||||
config_file = os.path.join(subfolder, cls.config_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
|
||||
)
|
||||
elif dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if subfolder is not None and os.path.isfile(
|
||||
@@ -426,10 +441,8 @@ class ConfigMixin:
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {cls.config_name} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
config_dict = cls._dict_from_json_file(config_file, dduf_reader=dduf_reader)
|
||||
|
||||
commit_hash = extract_commit_hash(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
@@ -552,9 +565,12 @@ class ConfigMixin:
|
||||
return init_dict, unused_kwargs, hidden_config_dict
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike], dduf_reader=None):
|
||||
if dduf_reader:
|
||||
text = dduf_reader.read_file(json_file, encoding="utf-8")
|
||||
else:
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return json.loads(text)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -38,7 +38,7 @@ deps = {
|
||||
"regex": "regex!=2019.12.17",
|
||||
"requests": "requests",
|
||||
"tensorboard": "tensorboard",
|
||||
"torch": "torch>=1.4,<2.5.0",
|
||||
"torch": "torch>=1.4",
|
||||
"torchvision": "torchvision",
|
||||
"transformers": "transformers>=4.41.2",
|
||||
"urllib3": "urllib3<=2.0.0",
|
||||
|
||||
@@ -795,13 +795,11 @@ class VaeImageProcessor(ConfigMixin):
|
||||
The final image with the overlay applied.
|
||||
"""
|
||||
|
||||
width, height = image.width, image.height
|
||||
|
||||
init_image = self.resize(init_image, width=width, height=height)
|
||||
mask = self.resize(mask, width=width, height=height)
|
||||
width, height = init_image.width, init_image.height
|
||||
|
||||
init_image_masked = PIL.Image.new("RGBa", (width, height))
|
||||
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
|
||||
|
||||
init_image_masked = init_image_masked.convert("RGBA")
|
||||
|
||||
if crop_coords is not None:
|
||||
|
||||
@@ -68,6 +68,7 @@ if is_torch_available():
|
||||
"LoraLoaderMixin",
|
||||
"FluxLoraLoaderMixin",
|
||||
"CogVideoXLoraLoaderMixin",
|
||||
"Mochi1LoraLoaderMixin",
|
||||
]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
@@ -88,6 +89,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogVideoXLoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
SD3LoraLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
|
||||
@@ -33,16 +33,14 @@ from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPVisionModelWithProjection,
|
||||
)
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -76,7 +74,7 @@ class IPAdapterMixin:
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`weight_name`.
|
||||
`subfolder`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
@@ -284,7 +282,9 @@ class IPAdapterMixin:
|
||||
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||
|
||||
for attn_name, attn_processor in unet.attn_processors.items():
|
||||
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
|
||||
if isinstance(
|
||||
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to "
|
||||
@@ -342,7 +342,9 @@ class IPAdapterMixin:
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
|
||||
if isinstance(
|
||||
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
)
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
@@ -636,10 +636,15 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
|
||||
block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
|
||||
new_key = f"transformer.single_transformer_blocks.{block_num}"
|
||||
|
||||
if "proj_lora1" in old_key or "proj_lora2" in old_key:
|
||||
if "proj_lora" in old_key:
|
||||
new_key += ".proj_out"
|
||||
elif "qkv_lora1" in old_key or "qkv_lora2" in old_key:
|
||||
new_key += ".norm.linear"
|
||||
elif "qkv_lora" in old_key and "up" not in old_key:
|
||||
handle_qkv(
|
||||
old_state_dict,
|
||||
new_state_dict,
|
||||
old_key,
|
||||
[f"transformer.single_transformer_blocks.{block_num}.norm.linear"],
|
||||
)
|
||||
|
||||
if "down" in old_key:
|
||||
new_key += ".lora_A.weight"
|
||||
|
||||
@@ -298,8 +298,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
if not only_text_encoder:
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_attn_procs(
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
@@ -827,8 +828,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
if not only_text_encoder:
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_attn_procs(
|
||||
unet.load_lora_adapter(
|
||||
state_dict,
|
||||
prefix=cls.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
@@ -2362,7 +2364,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
|
||||
class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoX`].
|
||||
Load LoRA layers into [`CogVideoXTransformer3DModel`]. Specific to [`CogVideoXPipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
@@ -2667,6 +2669,314 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`MochiTransformer3DModel`]. Specific to [`MochiPipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
state_dict = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
return state_dict
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
|
||||
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
|
||||
):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
|
||||
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
|
||||
encoder lora layers.
|
||||
transformer (`CogVideoXTransformer3DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the UNet and text encoder.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.fuse_lora with unet->transformer
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer", "text_encoder"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
super().unfuse_lora(components=components)
|
||||
|
||||
|
||||
class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
|
||||
|
||||
@@ -13,9 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..utils import (
|
||||
@@ -48,6 +52,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"SD3Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"FluxTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"MochiTransformer3DModel": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
@@ -189,40 +194,45 @@ class PeftAdapterMixin:
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
if network_alphas is not None and prefix is None:
|
||||
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
transformer_keys = [k for k in keys if k.startswith(prefix)]
|
||||
if len(transformer_keys) > 0:
|
||||
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
|
||||
if prefix is not None:
|
||||
keys = list(state_dict.keys())
|
||||
model_keys = [k for k in keys if k.startswith(f"{prefix}.")]
|
||||
if len(model_keys) > 0:
|
||||
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in model_keys}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
if adapter_name in getattr(self, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the model - please select a new adapter name."
|
||||
)
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
# check with first key if is not in peft format
|
||||
first_key = next(iter(state_dict.keys()))
|
||||
if "lora_A" not in first_key:
|
||||
state_dict = convert_unet_state_dict_to_peft(state_dict)
|
||||
|
||||
if adapter_name in getattr(self, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
if network_alphas is not None and len(network_alphas) >= 1:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(f"{prefix}.")]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
@@ -276,6 +286,69 @@ class PeftAdapterMixin:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
def save_lora_adapter(
|
||||
self,
|
||||
save_directory,
|
||||
adapter_name: str = "default",
|
||||
upcast_before_saving: bool = False,
|
||||
safe_serialization: bool = True,
|
||||
weight_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Save the LoRA parameters corresponding to the underlying model.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
adapter_name: (`str`, defaults to "default"): The name of the adapter to serialize. Useful when the
|
||||
underlying model has multiple adapters loaded.
|
||||
upcast_before_saving (`bool`, defaults to `False`):
|
||||
Whether to cast the underlying model to `torch.float32` before serialization.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
|
||||
"""
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
|
||||
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(self)
|
||||
|
||||
if adapter_name not in getattr(self, "peft_config", {}):
|
||||
raise ValueError(f"Adapter name {adapter_name} not found in the model.")
|
||||
|
||||
lora_layers_to_save = get_peft_model_state_dict(
|
||||
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
|
||||
)
|
||||
if os.path.isfile(save_directory):
|
||||
raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
# TODO: we could consider saving the `peft_config` as well.
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(lora_layers_to_save, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
|
||||
@@ -269,6 +269,7 @@ class FromOriginalModelMixin:
|
||||
pretrained_model_name_or_path=default_pretrained_model_config_name,
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
|
||||
@@ -62,7 +62,14 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
|
||||
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
|
||||
"upscale": "model.diffusion_model.input_blocks.10.0.skip_connection.bias",
|
||||
"controlnet": "control_model.time_embed.0.weight",
|
||||
"controlnet": [
|
||||
"control_model.time_embed.0.weight",
|
||||
"controlnet_cond_embedding.conv_in.weight",
|
||||
],
|
||||
# TODO: find non-Diffusers keys for controlnet_xl
|
||||
"controlnet_xl": "add_embedding.linear_1.weight",
|
||||
"controlnet_xl_large": "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_k.weight",
|
||||
"controlnet_xl_mid": "down_blocks.1.attentions.0.norm.weight",
|
||||
"playground-v2-5": "edm_mean",
|
||||
"inpainting": "model.diffusion_model.input_blocks.0.0.weight",
|
||||
"clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
|
||||
@@ -96,6 +103,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"inpainting": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-inpainting"},
|
||||
"inpainting_v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-inpainting"},
|
||||
"controlnet": {"pretrained_model_name_or_path": "lllyasviel/control_v11p_sd15_canny"},
|
||||
"controlnet_xl_large": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0"},
|
||||
"controlnet_xl_mid": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-mid"},
|
||||
"controlnet_xl_small": {"pretrained_model_name_or_path": "diffusers/controlnet-canny-sdxl-1.0-small"},
|
||||
"v2": {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1"},
|
||||
"v1": {"pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5"},
|
||||
"stable_cascade_stage_b": {"pretrained_model_name_or_path": "stabilityai/stable-cascade", "subfolder": "decoder"},
|
||||
@@ -117,6 +127,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"sd35_large": {
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-large",
|
||||
},
|
||||
"sd35_medium": {
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3.5-medium",
|
||||
},
|
||||
"animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
|
||||
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
|
||||
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
|
||||
@@ -481,8 +494,16 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif CHECKPOINT_KEY_NAMES["upscale"] in checkpoint:
|
||||
model_type = "upscale"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["controlnet"] in checkpoint:
|
||||
model_type = "controlnet"
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["controlnet"]):
|
||||
if CHECKPOINT_KEY_NAMES["controlnet_xl"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["controlnet_xl_large"] in checkpoint:
|
||||
model_type = "controlnet_xl_large"
|
||||
elif CHECKPOINT_KEY_NAMES["controlnet_xl_mid"] in checkpoint:
|
||||
model_type = "controlnet_xl_mid"
|
||||
else:
|
||||
model_type = "controlnet_xl_small"
|
||||
else:
|
||||
model_type = "controlnet"
|
||||
|
||||
elif (
|
||||
CHECKPOINT_KEY_NAMES["stable_cascade_stage_c"] in checkpoint
|
||||
@@ -509,7 +530,10 @@ def infer_diffusers_model_type(checkpoint):
|
||||
model_type = "stable_cascade_stage_b"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
|
||||
model_type = "sd3"
|
||||
if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864:
|
||||
model_type = "sd3"
|
||||
elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456:
|
||||
model_type = "sd35_medium"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
|
||||
model_type = "sd35_large"
|
||||
@@ -1072,6 +1096,9 @@ def convert_controlnet_checkpoint(
|
||||
config,
|
||||
**kwargs,
|
||||
):
|
||||
# Return checkpoint if it's already been converted
|
||||
if "time_embedding.linear_1.weight" in checkpoint:
|
||||
return checkpoint
|
||||
# Some controlnet ckpt files are distributed independently from the rest of the
|
||||
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
|
||||
if "time_embed.0.weight" in checkpoint:
|
||||
|
||||
@@ -497,19 +497,19 @@ class TextualInversionLoaderMixin:
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=["<s0>", "<s1>"],
|
||||
tokens=["<s0>", "<s1>"],
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_g"],
|
||||
token=["<s0>", "<s1>"],
|
||||
tokens=["<s0>", "<s1>"],
|
||||
text_encoder=pipeline.text_encoder_2,
|
||||
tokenizer=pipeline.tokenizer_2,
|
||||
)
|
||||
|
||||
# Unload explicitly from both text encoders abd tokenizers
|
||||
# Unload explicitly from both text encoders and tokenizers
|
||||
pipeline.unload_textual_inversion(
|
||||
tokens=["<s0>", "<s1>"], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
|
||||
)
|
||||
|
||||
@@ -36,6 +36,7 @@ from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_unet_state_dict_to_peft,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
@@ -209,6 +210,10 @@ class UNet2DConditionLoadersMixin:
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if is_lora:
|
||||
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
|
||||
deprecate("load_attn_procs", "0.40.0", deprecation_message)
|
||||
|
||||
if is_custom_diffusion:
|
||||
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
|
||||
elif is_lora:
|
||||
@@ -765,6 +770,7 @@ class UNet2DConditionLoadersMixin:
|
||||
from ..models.attention_processor import (
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
@@ -804,11 +810,15 @@ class UNet2DConditionLoadersMixin:
|
||||
if cross_attention_dim is None or "motion_modules" in name:
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
|
||||
else:
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
|
||||
)
|
||||
if "XFormers" in str(self.attn_processors[name].__class__):
|
||||
attn_processor_class = IPAdapterXFormersAttnProcessor
|
||||
else:
|
||||
attn_processor_class = (
|
||||
IPAdapterAttnProcessor2_0
|
||||
if hasattr(F, "scaled_dot_product_attention")
|
||||
else IPAdapterAttnProcessor
|
||||
)
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
|
||||
@@ -318,7 +318,10 @@ class Attention(nn.Module):
|
||||
XFormersAttnAddedKVProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
is_ip_adapter = hasattr(self, "processor") and isinstance(
|
||||
self.processor,
|
||||
(IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
|
||||
)
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if is_added_kv_processor and is_custom_diffusion:
|
||||
raise NotImplementedError(
|
||||
@@ -368,6 +371,19 @@ class Attention(nn.Module):
|
||||
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
|
||||
)
|
||||
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
|
||||
elif is_ip_adapter:
|
||||
processor = IPAdapterXFormersAttnProcessor(
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
num_tokens=self.processor.num_tokens,
|
||||
scale=self.processor.scale,
|
||||
attention_op=attention_op,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_ip"):
|
||||
processor.to(
|
||||
device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
|
||||
)
|
||||
else:
|
||||
processor = XFormersAttnProcessor(attention_op=attention_op)
|
||||
else:
|
||||
@@ -386,6 +402,18 @@ class Attention(nn.Module):
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_custom_diffusion"):
|
||||
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
||||
elif is_ip_adapter:
|
||||
processor = IPAdapterAttnProcessor2_0(
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
num_tokens=self.processor.num_tokens,
|
||||
scale=self.processor.scale,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_ip"):
|
||||
processor.to(
|
||||
device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
|
||||
)
|
||||
else:
|
||||
# set attention processor
|
||||
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
||||
@@ -1143,6 +1171,7 @@ class PAGJointAttnProcessor2_0:
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
@@ -4542,6 +4571,238 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAdapterXFormersAttnProcessor(torch.nn.Module):
|
||||
r"""
|
||||
Attention processor for IP-Adapter using xFormers.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
|
||||
The context length of the image features.
|
||||
scale (`float` or `List[float]`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
attention_op (`Callable`, *optional*, defaults to `None`):
|
||||
The base
|
||||
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
|
||||
use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
|
||||
operator.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
cross_attention_dim=None,
|
||||
num_tokens=(4,),
|
||||
scale=1.0,
|
||||
attention_op: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_op = attention_op
|
||||
|
||||
if not isinstance(num_tokens, (tuple, list)):
|
||||
num_tokens = [num_tokens]
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * len(num_tokens)
|
||||
if len(scale) != len(num_tokens):
|
||||
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
self.to_v_ip = nn.ModuleList(
|
||||
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
scale: float = 1.0,
|
||||
ip_adapter_masks: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
# separate ip_hidden_states from encoder_hidden_states
|
||||
if encoder_hidden_states is not None:
|
||||
if isinstance(encoder_hidden_states, tuple):
|
||||
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
|
||||
else:
|
||||
deprecation_message = (
|
||||
"You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
|
||||
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
|
||||
)
|
||||
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
|
||||
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, :end_pos, :],
|
||||
[encoder_hidden_states[:, end_pos:, :]],
|
||||
)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# expand our mask's singleton query_tokens dimension:
|
||||
# [batch*heads, 1, key_tokens] ->
|
||||
# [batch*heads, query_tokens, key_tokens]
|
||||
# so that it can be added as a bias onto the attention scores that xformers computes:
|
||||
# [batch*heads, query_tokens, key_tokens]
|
||||
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
|
||||
_, query_tokens, _ = hidden_states.shape
|
||||
attention_mask = attention_mask.expand(-1, query_tokens, -1)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, key, value, attn_bias=attention_mask, op=self.attention_op
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if ip_hidden_states:
|
||||
if ip_adapter_masks is not None:
|
||||
if not isinstance(ip_adapter_masks, List):
|
||||
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
|
||||
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
|
||||
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
|
||||
raise ValueError(
|
||||
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
|
||||
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
|
||||
f"({len(ip_hidden_states)})"
|
||||
)
|
||||
else:
|
||||
for index, (mask, scale, ip_state) in enumerate(
|
||||
zip(ip_adapter_masks, self.scale, ip_hidden_states)
|
||||
):
|
||||
if mask is None:
|
||||
continue
|
||||
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
|
||||
raise ValueError(
|
||||
"Each element of the ip_adapter_masks array should be a tensor with shape "
|
||||
"[1, num_images_for_ip_adapter, height, width]."
|
||||
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
|
||||
)
|
||||
if mask.shape[1] != ip_state.shape[1]:
|
||||
raise ValueError(
|
||||
f"Number of masks ({mask.shape[1]}) does not match "
|
||||
f"number of ip images ({ip_state.shape[1]}) at index {index}"
|
||||
)
|
||||
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
|
||||
raise ValueError(
|
||||
f"Number of masks ({mask.shape[1]}) does not match "
|
||||
f"number of scales ({len(scale)}) at index {index}"
|
||||
)
|
||||
else:
|
||||
ip_adapter_masks = [None] * len(self.scale)
|
||||
|
||||
# for ip-adapter
|
||||
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
|
||||
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
|
||||
):
|
||||
skip = False
|
||||
if isinstance(scale, list):
|
||||
if all(s == 0 for s in scale):
|
||||
skip = True
|
||||
elif scale == 0:
|
||||
skip = True
|
||||
if not skip:
|
||||
if mask is not None:
|
||||
mask = mask.to(torch.float16)
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale] * mask.shape[1]
|
||||
|
||||
current_num_images = mask.shape[1]
|
||||
for i in range(current_num_images):
|
||||
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
|
||||
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
||||
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
||||
|
||||
_current_ip_hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, ip_key, ip_value, op=self.attention_op
|
||||
)
|
||||
_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
|
||||
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)
|
||||
|
||||
mask_downsample = IPAdapterMaskProcessor.downsample(
|
||||
mask[:, i, :, :],
|
||||
batch_size,
|
||||
_current_ip_hidden_states.shape[1],
|
||||
_current_ip_hidden_states.shape[2],
|
||||
)
|
||||
|
||||
mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
|
||||
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
|
||||
else:
|
||||
ip_key = to_k_ip(current_ip_hidden_states)
|
||||
ip_value = to_v_ip(current_ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
|
||||
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
|
||||
|
||||
current_ip_hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, ip_key, ip_value, op=self.attention_op
|
||||
)
|
||||
current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
|
||||
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + scale * current_ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAGIdentitySelfAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
@@ -4793,19 +5054,46 @@ CROSS_ATTENTION_PROCESSORS = (
|
||||
|
||||
AttentionProcessor = Union[
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
XFormersAttnProcessor,
|
||||
SlicedAttnProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
XFormersAttnAddedKVProcessor,
|
||||
CustomDiffusionAttnProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnAddedKVProcessor2_0,
|
||||
JointAttnProcessor2_0,
|
||||
PAGJointAttnProcessor2_0,
|
||||
PAGCFGJointAttnProcessor2_0,
|
||||
FusedJointAttnProcessor2_0,
|
||||
AllegroAttnProcessor2_0,
|
||||
AuraFlowAttnProcessor2_0,
|
||||
FusedAuraFlowAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxAttnProcessor2_0_NPU,
|
||||
FusedFluxAttnProcessor2_0,
|
||||
FusedFluxAttnProcessor2_0_NPU,
|
||||
CogVideoXAttnProcessor2_0,
|
||||
FusedCogVideoXAttnProcessor2_0,
|
||||
XFormersAttnAddedKVProcessor,
|
||||
XFormersAttnProcessor,
|
||||
AttnProcessorNPU,
|
||||
AttnProcessor2_0,
|
||||
MochiVaeAttnProcessor2_0,
|
||||
StableAudioAttnProcessor2_0,
|
||||
HunyuanAttnProcessor2_0,
|
||||
FusedHunyuanAttnProcessor2_0,
|
||||
PAGHunyuanAttnProcessor2_0,
|
||||
PAGCFGHunyuanAttnProcessor2_0,
|
||||
LuminaAttnProcessor2_0,
|
||||
MochiAttnProcessor2_0,
|
||||
FusedAttnProcessor2_0,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
PAGCFGIdentitySelfAttnProcessor2_0,
|
||||
SlicedAttnProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
PAGIdentitySelfAttnProcessor2_0,
|
||||
PAGCFGHunyuanAttnProcessor2_0,
|
||||
PAGHunyuanAttnProcessor2_0,
|
||||
PAGCFGIdentitySelfAttnProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
LoRAAttnProcessor2_0,
|
||||
LoRAXFormersAttnProcessor,
|
||||
LoRAAttnAddedKVProcessor,
|
||||
]
|
||||
|
||||
@@ -17,6 +17,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import deprecate
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
@@ -34,7 +35,7 @@ from ..modeling_utils import ModelMixin
|
||||
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
|
||||
|
||||
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
|
||||
@@ -506,7 +506,7 @@ class AllegroEncoder3D(nn.Module):
|
||||
sample = self.temp_conv_in(sample)
|
||||
sample = sample + residual
|
||||
|
||||
if self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -646,7 +646,7 @@ class AllegroDecoder3D(nn.Module):
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
|
||||
if self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -420,7 +420,7 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
@@ -433,7 +433,7 @@ class CogVideoXDownBlock3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -522,7 +522,7 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
@@ -531,7 +531,7 @@ class CogVideoXMidBlock3D(nn.Module):
|
||||
return create_forward
|
||||
|
||||
hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
|
||||
create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -636,7 +636,7 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
@@ -649,7 +649,7 @@ class CogVideoXUpBlock3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
zq,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
hidden_states, new_conv_cache[conv_cache_key] = resnet(
|
||||
@@ -773,7 +773,7 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
|
||||
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -789,7 +789,7 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
@@ -798,14 +798,14 @@ class CogVideoXEncoder3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
None,
|
||||
conv_cache=conv_cache.get("mid_block"),
|
||||
conv_cache.get("mid_block"),
|
||||
)
|
||||
else:
|
||||
# 1. Down
|
||||
for i, down_block in enumerate(self.down_blocks):
|
||||
conv_cache_key = f"down_block_{i}"
|
||||
hidden_states, new_conv_cache[conv_cache_key] = down_block(
|
||||
hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
|
||||
hidden_states, temb, None, conv_cache.get(conv_cache_key)
|
||||
)
|
||||
|
||||
# 2. Mid
|
||||
@@ -939,7 +939,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
|
||||
hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -953,7 +953,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
conv_cache=conv_cache.get("mid_block"),
|
||||
conv_cache.get("mid_block"),
|
||||
)
|
||||
|
||||
# 2. Up
|
||||
@@ -964,7 +964,7 @@ class CogVideoXDecoder3D(nn.Module):
|
||||
hidden_states,
|
||||
temb,
|
||||
sample,
|
||||
conv_cache=conv_cache.get(conv_cache_key),
|
||||
conv_cache.get(conv_cache_key),
|
||||
)
|
||||
else:
|
||||
# 1. Mid
|
||||
@@ -1057,6 +1057,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
force_upcast: float = True,
|
||||
use_quant_conv: bool = False,
|
||||
use_post_quant_conv: bool = False,
|
||||
invert_scale_latents: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -1475,7 +1476,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
dec = self.decode(z).sample
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return dec
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@@ -206,7 +206,7 @@ class MochiDownBlock3D(nn.Module):
|
||||
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
@@ -311,7 +311,7 @@ class MochiMidBlock3D(nn.Module):
|
||||
for i, (resnet, norm, attn) in enumerate(zip(self.resnets, self.norms, self.attentions)):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
@@ -392,7 +392,7 @@ class MochiUpBlock3D(nn.Module):
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
conv_cache_key = f"resnet_{i}"
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
@@ -437,7 +437,8 @@ class FourierFeatures(nn.Module):
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
r"""Forward method of the `FourierFeatures` class."""
|
||||
|
||||
original_dtype = inputs.dtype
|
||||
inputs = inputs.to(torch.float32)
|
||||
num_channels = inputs.shape[1]
|
||||
num_freqs = (self.stop - self.start) // self.step
|
||||
|
||||
@@ -450,7 +451,7 @@ class FourierFeatures(nn.Module):
|
||||
# Scale channels by frequency.
|
||||
h = w * h
|
||||
|
||||
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1)
|
||||
return torch.cat([inputs, torch.sin(h), torch.cos(h)], dim=1).to(original_dtype)
|
||||
|
||||
|
||||
class MochiEncoder3D(nn.Module):
|
||||
@@ -529,7 +530,7 @@ class MochiEncoder3D(nn.Module):
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 2, 3)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
@@ -646,7 +647,7 @@ class MochiDecoder3D(nn.Module):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
# 1. Mid
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_forward(*inputs):
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -94,8 +95,8 @@ class TemporalDecoder(nn.Module):
|
||||
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if self.training and self.gradient_checkpointing:
|
||||
upscale_dtype = next(itertools.chain(self.up_blocks.parameters(), self.up_blocks.buffers())).dtype
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -228,14 +229,6 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (Encoder, TemporalDecoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@@ -310,7 +310,9 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
|
||||
) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
output = [self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x) for x_slice in x.split(1)]
|
||||
output = [
|
||||
self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
|
||||
]
|
||||
output = torch.cat(output)
|
||||
else:
|
||||
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
|
||||
@@ -341,7 +343,7 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
|
||||
# as if we were loading the latents from an RGBA uint8 image.
|
||||
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
|
||||
|
||||
dec = self.decode(unscaled_enc)
|
||||
dec = self.decode(unscaled_enc).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
@@ -142,7 +142,7 @@ class Encoder(nn.Module):
|
||||
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -291,7 +291,7 @@ class Decoder(nn.Module):
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -544,7 +544,7 @@ class MaskConditionDecoder(nn.Module):
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -876,7 +876,7 @@ class EncoderTiny(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""The forward method of the `EncoderTiny` class."""
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -962,7 +962,7 @@ class DecoderTiny(nn.Module):
|
||||
# Clamp.
|
||||
x = torch.tanh(x / 3) * 3
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -11,9 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate
|
||||
from .controlnets.controlnet import ( # noqa
|
||||
BaseOutput,
|
||||
ControlNetConditioningEmbedding,
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
@@ -24,19 +25,91 @@ from .controlnets.controlnet import ( # noqa
|
||||
class ControlNetOutput(ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
|
||||
deprecate("ControlNetOutput", "0.34", deprecation_message)
|
||||
deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ControlNetModel(ControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
|
||||
deprecate("ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
freq_shift=freq_shift,
|
||||
down_block_types=down_block_types,
|
||||
mid_block_type=mid_block_type,
|
||||
only_cross_attention=only_cross_attention,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_padding=downsample_padding,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_eps=norm_eps,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=use_linear_projection,
|
||||
class_embed_type=class_embed_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
num_class_embeds=num_class_embeds,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
global_pool_conditions=global_pool_conditions,
|
||||
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
|
||||
deprecate("ControlNetConditioningEmbedding", "0.34", deprecation_message)
|
||||
deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
|
||||
@@ -23,19 +25,46 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
class FluxControlNetOutput(FluxControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
|
||||
deprecate("FluxControlNetOutput", "0.34", deprecation_message)
|
||||
deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class FluxControlNetModel(FluxControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
|
||||
deprecate("FluxControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
guidance_embeds=guidance_embeds,
|
||||
axes_dims_rope=axes_dims_rope,
|
||||
num_mode=num_mode,
|
||||
conditioning_embedding_channels=conditioning_embedding_channels,
|
||||
)
|
||||
|
||||
|
||||
class FluxMultiControlNetModel(FluxMultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
|
||||
deprecate("FluxMultiControlNetModel", "0.34", deprecation_message)
|
||||
deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -23,19 +23,46 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
class SD3ControlNetOutput(SD3ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
|
||||
deprecate("SD3ControlNetOutput", "0.34", deprecation_message)
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SD3ControlNetModel(SD3ControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
num_layers: int = 18,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 18,
|
||||
joint_attention_dim: int = 4096,
|
||||
caption_projection_dim: int = 1152,
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
):
|
||||
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
|
||||
deprecate("SD3ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
sample_size=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
caption_projection_dim=caption_projection_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
out_channels=out_channels,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
extra_conditioning_channels=extra_conditioning_channels,
|
||||
)
|
||||
|
||||
|
||||
class SD3MultiControlNetModel(SD3MultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
|
||||
deprecate("SD3MultiControlNetModel", "0.34", deprecation_message)
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sparsectrl import ( # noqa
|
||||
SparseControlNetConditioningEmbedding,
|
||||
@@ -28,19 +30,87 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
class SparseControlNetOutput(SparseControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
|
||||
deprecate("SparseControlNetOutput", "0.34", deprecation_message)
|
||||
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
|
||||
deprecate("SparseControlNetConditioningEmbedding", "0.34", deprecation_message)
|
||||
deprecate(
|
||||
"diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SparseControlNetModel(SparseControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
motion_num_attention_heads: int = 8,
|
||||
concat_conditioning_mask: bool = True,
|
||||
use_simplified_condition_embedding: bool = True,
|
||||
):
|
||||
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
|
||||
deprecate("SparseControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
freq_shift=freq_shift,
|
||||
down_block_types=down_block_types,
|
||||
only_cross_attention=only_cross_attention,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_padding=downsample_padding,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_eps=norm_eps,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
transformer_layers_per_mid_block=transformer_layers_per_mid_block,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
global_pool_conditions=global_pool_conditions,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
motion_max_seq_length=motion_max_seq_length,
|
||||
motion_num_attention_heads=motion_num_attention_heads,
|
||||
concat_conditioning_mask=concat_conditioning_mask,
|
||||
use_simplified_condition_embedding=use_simplified_condition_embedding,
|
||||
)
|
||||
|
||||
@@ -22,8 +22,8 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...models.attention_processor import AttentionProcessor
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
@@ -192,13 +192,13 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
num_attention_heads: int = 24,
|
||||
load_weights_from_transformer=True,
|
||||
):
|
||||
config = transformer.config
|
||||
config = dict(transformer.config)
|
||||
config["num_layers"] = num_layers
|
||||
config["num_single_layers"] = num_single_layers
|
||||
config["attention_head_dim"] = attention_head_dim
|
||||
config["num_attention_heads"] = num_attention_heads
|
||||
|
||||
controlnet = cls(**config)
|
||||
controlnet = cls.from_config(config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
@@ -329,7 +329,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
block_samples = ()
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
@@ -363,7 +363,7 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..embeddings import (
|
||||
HunyuanCombinedTimestepTextSizeStyleEmbedding,
|
||||
@@ -27,7 +27,7 @@ from ..embeddings import (
|
||||
)
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.hunyuan_transformer_2d import HunyuanDiTBlock
|
||||
from .controlnet import BaseOutput, Tuple, zero_module
|
||||
from .controlnet import Tuple, zero_module
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -27,6 +27,7 @@ from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnP
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_sd3 import SD3SingleTransformerBlock
|
||||
from .controlnet import BaseOutput, zero_module
|
||||
|
||||
|
||||
@@ -56,38 +57,62 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
dual_attention_layers: Tuple[int, ...] = (),
|
||||
qk_norm: Optional[str] = None,
|
||||
pos_embed_type: Optional[str] = "sincos",
|
||||
use_pos_embed: bool = True,
|
||||
force_zeros_for_pooled_projection: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
default_out_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
)
|
||||
if use_pos_embed:
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
pos_embed_type=pos_embed_type,
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
|
||||
)
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
|
||||
if joint_attention_dim is not None:
|
||||
self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
|
||||
|
||||
# `attention_head_dim` is doubled to account for the mixing.
|
||||
# It needs to crafted when we get the actual checkpoints.
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
context_pre_only=False,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
# `attention_head_dim` is doubled to account for the mixing.
|
||||
# It needs to crafted when we get the actual checkpoints.
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
context_pre_only=False,
|
||||
qk_norm=qk_norm,
|
||||
use_dual_attention=True if i in dual_attention_layers else False,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.context_embedder = None
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
SD3SingleTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=self.config.attention_head_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# controlnet_blocks
|
||||
self.controlnet_blocks = nn.ModuleList([])
|
||||
@@ -241,6 +266,20 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
# Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
|
||||
# we should have handled this in conversion script
|
||||
def _get_pos_embed_from_transformer(self, transformer):
|
||||
pos_embed = PatchEmbed(
|
||||
height=transformer.config.sample_size,
|
||||
width=transformer.config.sample_size,
|
||||
patch_size=transformer.config.patch_size,
|
||||
in_channels=transformer.config.in_channels,
|
||||
embed_dim=transformer.inner_dim,
|
||||
pos_embed_max_size=transformer.config.pos_embed_max_size,
|
||||
)
|
||||
pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=True)
|
||||
return pos_embed
|
||||
|
||||
@classmethod
|
||||
def from_transformer(
|
||||
cls, transformer, num_layers=12, num_extra_conditioning_channels=1, load_weights_from_transformer=True
|
||||
@@ -248,7 +287,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
config = transformer.config
|
||||
config["num_layers"] = num_layers or config.num_layers
|
||||
config["extra_conditioning_channels"] = num_extra_conditioning_channels
|
||||
controlnet = cls(**config)
|
||||
controlnet = cls.from_config(config)
|
||||
|
||||
if load_weights_from_transformer:
|
||||
controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
|
||||
@@ -314,9 +353,27 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
if self.pos_embed is not None and hidden_states.ndim != 4:
|
||||
raise ValueError("hidden_states must be 4D when pos_embed is used")
|
||||
|
||||
# SD3.5 8b controlnet does not have a `pos_embed`,
|
||||
# it use the `pos_embed` from the transformer to process input before passing to controlnet
|
||||
elif self.pos_embed is None and hidden_states.ndim != 3:
|
||||
raise ValueError("hidden_states must be 3D when pos_embed is not used")
|
||||
|
||||
if self.context_embedder is not None and encoder_hidden_states is None:
|
||||
raise ValueError("encoder_hidden_states must be provided when context_embedder is used")
|
||||
# SD3.5 8b controlnet does not have a `context_embedder`, it does not use `encoder_hidden_states`
|
||||
elif self.context_embedder is None and encoder_hidden_states is not None:
|
||||
raise ValueError("encoder_hidden_states should not be provided when context_embedder is not used")
|
||||
|
||||
if self.pos_embed is not None:
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
if self.context_embedder is not None:
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# add
|
||||
hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
|
||||
@@ -324,7 +381,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
block_res_samples = ()
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
@@ -345,9 +402,13 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
if self.context_embedder is not None:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
else:
|
||||
# SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
|
||||
hidden_states = block(hidden_states, temb)
|
||||
|
||||
block_res_samples = block_res_samples + (hidden_states,)
|
||||
|
||||
|
||||
@@ -1466,7 +1466,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
|
||||
|
||||
# apply base subblock
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
h_base = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(b_res),
|
||||
@@ -1489,7 +1489,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
# apply ctrl subblock
|
||||
if apply_control:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
h_ctrl = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(c_res),
|
||||
@@ -1898,7 +1898,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
|
||||
hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
|
||||
hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
|
||||
@@ -82,7 +82,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method.
|
||||
`[`~models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained`]` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
@@ -128,7 +128,7 @@ class MultiControlNetModel(ModelMixin):
|
||||
Parameters:
|
||||
pretrained_model_path (`os.PathLike`):
|
||||
A path to a *directory* containing model weights saved using
|
||||
[`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
|
||||
[`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
|
||||
`./my_model_directory/controlnet`.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
|
||||
|
||||
@@ -86,12 +86,25 @@ def get_3d_sincos_pos_embed(
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
) -> np.ndarray:
|
||||
r"""
|
||||
Creates 3D sinusoidal positional embeddings.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`):
|
||||
The embedding dimension of inputs. It must be divisible by 16.
|
||||
spatial_size (`int` or `Tuple[int, int]`):
|
||||
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
|
||||
spatial dimensions (height and width).
|
||||
temporal_size (`int`):
|
||||
The temporal dimension of postional embeddings (number of frames).
|
||||
spatial_interpolation_scale (`float`, defaults to 1.0):
|
||||
Scale factor for spatial grid interpolation.
|
||||
temporal_interpolation_scale (`float`, defaults to 1.0):
|
||||
Scale factor for temporal grid interpolation.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`:
|
||||
The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
|
||||
embed_dim]`.
|
||||
"""
|
||||
if embed_dim % 4 != 0:
|
||||
raise ValueError("`embed_dim` must be divisible by 4")
|
||||
@@ -129,8 +142,24 @@ def get_2d_sincos_pos_embed(
|
||||
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
|
||||
):
|
||||
"""
|
||||
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
||||
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
Creates 2D sinusoidal positional embeddings.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`):
|
||||
The embedding dimension.
|
||||
grid_size (`int`):
|
||||
The size of the grid height and width.
|
||||
cls_token (`bool`, defaults to `False`):
|
||||
Whether or not to add a classification token.
|
||||
extra_tokens (`int`, defaults to `0`):
|
||||
The number of extra tokens to add.
|
||||
interpolation_scale (`float`, defaults to `1.0`):
|
||||
The scale of the interpolation.
|
||||
|
||||
Returns:
|
||||
pos_embed (`np.ndarray`):
|
||||
Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
|
||||
embed_dim]` if using cls_token
|
||||
"""
|
||||
if isinstance(grid_size, int):
|
||||
grid_size = (grid_size, grid_size)
|
||||
@@ -148,6 +177,16 @@ def get_2d_sincos_pos_embed(
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
r"""
|
||||
This function generates 2D sinusoidal positional embeddings from a grid.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): The embedding dimension.
|
||||
grid (`np.ndarray`): Grid of positions with shape `(H * W,)`.
|
||||
|
||||
Returns:
|
||||
`np.ndarray`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
|
||||
"""
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
@@ -161,7 +200,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
||||
This function generates 1D positional embeddings from a grid.
|
||||
|
||||
Args:
|
||||
embed_dim (`int`): The embedding dimension `D`
|
||||
pos (`numpy.ndarray`): 1D tensor of positions with shape `(M,)`
|
||||
|
||||
Returns:
|
||||
`numpy.ndarray`: Sinusoidal positional embeddings of shape `(M, D)`.
|
||||
"""
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
@@ -181,7 +227,22 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding with support for SD3 cropping."""
|
||||
"""
|
||||
2D Image to Patch Embedding with support for SD3 cropping.
|
||||
|
||||
Args:
|
||||
height (`int`, defaults to `224`): The height of the image.
|
||||
width (`int`, defaults to `224`): The width of the image.
|
||||
patch_size (`int`, defaults to `16`): The size of the patches.
|
||||
in_channels (`int`, defaults to `3`): The number of input channels.
|
||||
embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
|
||||
layer_norm (`bool`, defaults to `False`): Whether or not to use layer normalization.
|
||||
flatten (`bool`, defaults to `True`): Whether or not to flatten the output.
|
||||
bias (`bool`, defaults to `True`): Whether or not to use bias.
|
||||
interpolation_scale (`float`, defaults to `1`): The scale of the interpolation.
|
||||
pos_embed_type (`str`, defaults to `"sincos"`): The type of positional embedding.
|
||||
pos_embed_max_size (`int`, defaults to `None`): The maximum size of the positional embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -289,7 +350,15 @@ class PatchEmbed(nn.Module):
|
||||
|
||||
|
||||
class LuminaPatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding with support for Lumina-T2X"""
|
||||
"""
|
||||
2D Image to Patch Embedding with support for Lumina-T2X
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`): The size of the patches.
|
||||
in_channels (`int`, defaults to `4`): The number of input channels.
|
||||
embed_dim (`int`, defaults to `768`): The output dimension of the embedding.
|
||||
bias (`bool`, defaults to `True`): Whether or not to use bias.
|
||||
"""
|
||||
|
||||
def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
|
||||
super().__init__()
|
||||
@@ -338,6 +407,7 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
patch_size_t: Optional[int] = None,
|
||||
in_channels: int = 16,
|
||||
embed_dim: int = 1920,
|
||||
text_embed_dim: int = 4096,
|
||||
@@ -355,6 +425,7 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.embed_dim = embed_dim
|
||||
self.sample_height = sample_height
|
||||
self.sample_width = sample_width
|
||||
@@ -366,9 +437,15 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
self.use_positional_embeddings = use_positional_embeddings
|
||||
self.use_learned_positional_embeddings = use_learned_positional_embeddings
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
if patch_size_t is None:
|
||||
# CogVideoX 1.0 checkpoints
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
else:
|
||||
# CogVideoX 1.5 checkpoints
|
||||
self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
|
||||
|
||||
self.text_proj = nn.Linear(text_embed_dim, embed_dim)
|
||||
|
||||
if use_positional_embeddings or use_learned_positional_embeddings:
|
||||
@@ -407,12 +484,24 @@ class CogVideoXPatchEmbed(nn.Module):
|
||||
"""
|
||||
text_embeds = self.text_proj(text_embeds)
|
||||
|
||||
batch, num_frames, channels, height, width = image_embeds.shape
|
||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||
image_embeds = self.proj(image_embeds)
|
||||
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
||||
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
||||
batch_size, num_frames, channels, height, width = image_embeds.shape
|
||||
|
||||
if self.patch_size_t is None:
|
||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||
image_embeds = self.proj(image_embeds)
|
||||
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
|
||||
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
|
||||
else:
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
|
||||
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
|
||||
image_embeds = image_embeds.reshape(
|
||||
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
|
||||
)
|
||||
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
|
||||
image_embeds = self.proj(image_embeds)
|
||||
|
||||
embeds = torch.cat(
|
||||
[text_embeds, image_embeds], dim=1
|
||||
@@ -497,7 +586,14 @@ class CogView3PlusPatchEmbed(nn.Module):
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
embed_dim,
|
||||
crops_coords,
|
||||
grid_size,
|
||||
temporal_size,
|
||||
theta: int = 10000,
|
||||
use_real: bool = True,
|
||||
grid_type: str = "linspace",
|
||||
max_size: Optional[Tuple[int, int]] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
RoPE for video tokens with 3D structure.
|
||||
@@ -513,17 +609,30 @@ def get_3d_rotary_pos_embed(
|
||||
The size of the temporal dimension.
|
||||
theta (`float`):
|
||||
Scaling factor for frequency computation.
|
||||
grid_type (`str`):
|
||||
Whether to use "linspace" or "slice" to compute grids.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
|
||||
"""
|
||||
if use_real is not True:
|
||||
raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
|
||||
start, stop = crops_coords
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
if grid_type == "linspace":
|
||||
start, stop = crops_coords
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
|
||||
grid_t = np.arange(temporal_size, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
elif grid_type == "slice":
|
||||
max_h, max_w = max_size
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
grid_h = np.arange(max_h, dtype=np.float32)
|
||||
grid_w = np.arange(max_w, dtype=np.float32)
|
||||
grid_t = np.arange(temporal_size, dtype=np.float32)
|
||||
else:
|
||||
raise ValueError("Invalid value passed for `grid_type`.")
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
@@ -559,6 +668,12 @@ def get_3d_rotary_pos_embed(
|
||||
t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t
|
||||
h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h
|
||||
w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w
|
||||
|
||||
if grid_type == "slice":
|
||||
t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
|
||||
h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
|
||||
w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
|
||||
|
||||
cos = combine_time_height_width(t_cos, h_cos, w_cos)
|
||||
sin = combine_time_height_width(t_sin, h_sin, w_sin)
|
||||
return cos, sin
|
||||
@@ -629,6 +744,20 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
"""
|
||||
Get 2D RoPE from grid.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
grid (`np.ndarray`):
|
||||
The grid of the positional embedding.
|
||||
use_real (`bool`):
|
||||
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
|
||||
"""
|
||||
assert embed_dim % 4 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
@@ -649,6 +778,23 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
|
||||
"""
|
||||
Get 2D RoPE from grid.
|
||||
|
||||
Args:
|
||||
embed_dim: (`int`):
|
||||
The embedding dimension size, corresponding to hidden_size_head.
|
||||
grid (`np.ndarray`):
|
||||
The grid of the positional embedding.
|
||||
linear_factor (`float`):
|
||||
The linear factor of the positional embedding, which is used to scale the positional embedding in the linear
|
||||
layer.
|
||||
ntk_factor (`float`):
|
||||
The ntk factor of the positional embedding, which is used to scale the positional embedding in the ntk layer.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
|
||||
"""
|
||||
assert embed_dim % 4 == 0
|
||||
|
||||
emb_h = get_1d_rotary_pos_embed(
|
||||
|
||||
@@ -128,7 +128,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
|
||||
return old_class
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, dduf_reader=None):
|
||||
"""
|
||||
Reads a checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
@@ -138,8 +138,15 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
||||
return checkpoint_file
|
||||
try:
|
||||
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
||||
if dduf_reader:
|
||||
checkpoint_file = dduf_reader.read_file(checkpoint_file)
|
||||
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
if dduf_reader:
|
||||
# tensors are loaded on cpu
|
||||
return safetensors.torch.load(checkpoint_file)
|
||||
else:
|
||||
return safetensors.torch.load_file(checkpoint_file, device="cpu")
|
||||
|
||||
else:
|
||||
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
|
||||
return torch.load(
|
||||
@@ -272,6 +279,7 @@ def _fetch_index_file(
|
||||
revision,
|
||||
user_agent,
|
||||
commit_hash,
|
||||
dduf_reader=None,
|
||||
):
|
||||
if is_local:
|
||||
index_file = Path(
|
||||
@@ -297,6 +305,7 @@ def _fetch_index_file(
|
||||
subfolder=None,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_reader=dduf_reader,
|
||||
)
|
||||
index_file = Path(index_file)
|
||||
except (EntryNotFoundError, EnvironmentError):
|
||||
|
||||
@@ -530,7 +530,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
private = kwargs.pop("private", False)
|
||||
private = kwargs.pop("private", None)
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
|
||||
@@ -338,7 +338,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
private = kwargs.pop("private", False)
|
||||
private = kwargs.pop("private", None)
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
@@ -557,6 +557,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
dduf_reader = kwargs.pop("dduf_reader", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
@@ -649,6 +650,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
dduf_reader=dduf_reader,
|
||||
**kwargs,
|
||||
)
|
||||
# no in-place modification of the original config.
|
||||
@@ -724,6 +726,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"revision": revision,
|
||||
"user_agent": user_agent,
|
||||
"commit_hash": commit_hash,
|
||||
"dduf_reader": dduf_reader,
|
||||
}
|
||||
index_file = _fetch_index_file(**index_file_kwargs)
|
||||
# In case the index file was not found we still have to consider the legacy format.
|
||||
@@ -759,7 +762,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
|
||||
else:
|
||||
if is_sharded:
|
||||
# in the case it is sharded, we have already the index
|
||||
if is_sharded and not dduf_reader:
|
||||
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
|
||||
pretrained_model_name_or_path,
|
||||
index_file,
|
||||
@@ -790,6 +794,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_reader=dduf_reader,
|
||||
)
|
||||
|
||||
except IOError as e:
|
||||
@@ -813,6 +818,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_reader=dduf_reader,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
@@ -837,7 +843,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
|
||||
elif is_quant_method_bnb:
|
||||
param_device = torch.cuda.current_device()
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
# move the params from meta device to cpu
|
||||
@@ -937,7 +943,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
|
||||
@@ -466,7 +466,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# MMDiT blocks.
|
||||
for index_block, block in enumerate(self.joint_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
@@ -497,7 +497,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -170,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
ofs_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
|
||||
text_embed_dim (`int`, defaults to `4096`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
@@ -177,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
dropout (`float`, defaults to `0.0`):
|
||||
The dropout probability to use.
|
||||
attention_bias (`bool`, defaults to `True`):
|
||||
Whether or not to use bias in the attention projection layers.
|
||||
Whether to use bias in the attention projection layers.
|
||||
sample_width (`int`, defaults to `90`):
|
||||
The width of the input latents.
|
||||
sample_height (`int`, defaults to `60`):
|
||||
@@ -198,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
timestep_activation_fn (`str`, defaults to `"silu"`):
|
||||
Activation function to use when generating the timestep embeddings.
|
||||
norm_elementwise_affine (`bool`, defaults to `True`):
|
||||
Whether or not to use elementwise affine in normalization layers.
|
||||
Whether to use elementwise affine in normalization layers.
|
||||
norm_eps (`float`, defaults to `1e-5`):
|
||||
The epsilon value to use in normalization layers.
|
||||
spatial_interpolation_scale (`float`, defaults to `1.875`):
|
||||
@@ -219,6 +221,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
time_embed_dim: int = 512,
|
||||
ofs_embed_dim: Optional[int] = None,
|
||||
text_embed_dim: int = 4096,
|
||||
num_layers: int = 30,
|
||||
dropout: float = 0.0,
|
||||
@@ -227,6 +230,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
sample_height: int = 60,
|
||||
sample_frames: int = 49,
|
||||
patch_size: int = 2,
|
||||
patch_size_t: Optional[int] = None,
|
||||
temporal_compression_ratio: int = 4,
|
||||
max_text_seq_length: int = 226,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
@@ -237,6 +241,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
temporal_interpolation_scale: float = 1.0,
|
||||
use_rotary_positional_embeddings: bool = False,
|
||||
use_learned_positional_embeddings: bool = False,
|
||||
patch_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
@@ -251,10 +256,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
# 1. Patch embedding
|
||||
self.patch_embed = CogVideoXPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
text_embed_dim=text_embed_dim,
|
||||
bias=True,
|
||||
bias=patch_bias,
|
||||
sample_width=sample_width,
|
||||
sample_height=sample_height,
|
||||
sample_frames=sample_frames,
|
||||
@@ -267,10 +273,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
self.embedding_dropout = nn.Dropout(dropout)
|
||||
|
||||
# 2. Time embeddings
|
||||
# 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
|
||||
|
||||
self.ofs_proj = None
|
||||
self.ofs_embedding = None
|
||||
if ofs_embed_dim:
|
||||
self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.ofs_embedding = TimestepEmbedding(
|
||||
ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
|
||||
) # same as time embeddings, for ofs
|
||||
|
||||
# 3. Define spatio-temporal transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
@@ -298,7 +313,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
norm_eps=norm_eps,
|
||||
chunk_dim=1,
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
if patch_size_t is None:
|
||||
# For CogVideox 1.0
|
||||
output_dim = patch_size * patch_size * out_channels
|
||||
else:
|
||||
# For CogVideoX 1.5
|
||||
output_dim = patch_size * patch_size * patch_size_t * out_channels
|
||||
|
||||
self.proj_out = nn.Linear(inner_dim, output_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@@ -411,6 +434,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: Union[int, float, torch.LongTensor],
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
ofs: Optional[Union[int, float, torch.LongTensor]] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
@@ -442,6 +466,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
|
||||
if self.ofs_embedding is not None:
|
||||
ofs_emb = self.ofs_proj(ofs)
|
||||
ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
|
||||
ofs_emb = self.ofs_embedding(ofs_emb)
|
||||
emb = emb + ofs_emb
|
||||
|
||||
# 2. Patch embedding
|
||||
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
|
||||
hidden_states = self.embedding_dropout(hidden_states)
|
||||
@@ -452,7 +482,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
# 3. Transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
@@ -491,12 +521,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. Unpatchify
|
||||
# Note: we use `-1` instead of `channels`:
|
||||
# - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
|
||||
# - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
|
||||
p = self.config.patch_size
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
p_t = self.config.patch_size_t
|
||||
|
||||
if p_t is None:
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
else:
|
||||
output = hidden_states.reshape(
|
||||
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
|
||||
)
|
||||
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
|
||||
@@ -184,7 +184,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
@@ -238,7 +238,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
for i, (spatial_block, temp_block) in enumerate(
|
||||
zip(self.transformer_blocks, self.temporal_transformer_blocks)
|
||||
):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
spatial_block,
|
||||
hidden_states,
|
||||
@@ -271,7 +271,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
if i == 0 and num_frame > 1:
|
||||
hidden_states = hidden_states + self.temp_pos_embed
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
temp_block,
|
||||
hidden_states,
|
||||
|
||||
@@ -386,7 +386,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user