Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6ca6fbd614 | |||
| 3e3d102f20 | |||
| 1b4c4d4614 | |||
| 28ef949cf6 |
@@ -63,27 +63,23 @@ body:
|
||||
|
||||
Please tag a maximum of 2 people.
|
||||
|
||||
Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...): @sayakpaul @DN6
|
||||
Questions on DiffusionPipeline (Saving, Loading, From pretrained, ...):
|
||||
|
||||
Questions on pipelines:
|
||||
- Stable Diffusion @yiyixuxu @asomoza
|
||||
- Stable Diffusion @yiyixuxu @DN6 @sayakpaul
|
||||
- Stable Diffusion XL @yiyixuxu @sayakpaul @DN6
|
||||
- Stable Diffusion 3: @yiyixuxu @sayakpaul @DN6 @asomoza
|
||||
- Kandinsky @yiyixuxu
|
||||
- ControlNet @sayakpaul @yiyixuxu @DN6
|
||||
- T2I Adapter @sayakpaul @yiyixuxu @DN6
|
||||
- IF @DN6
|
||||
- Text-to-Video / Video-to-Video @DN6 @a-r-r-o-w
|
||||
- Text-to-Video / Video-to-Video @DN6 @sayakpaul
|
||||
- Wuerstchen @DN6
|
||||
- Other: @yiyixuxu @DN6
|
||||
- Improving generation quality: @asomoza
|
||||
|
||||
Questions on models:
|
||||
- UNet @DN6 @yiyixuxu @sayakpaul
|
||||
- VAE @sayakpaul @DN6 @yiyixuxu
|
||||
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul
|
||||
|
||||
Questions on single file checkpoints: @DN6
|
||||
- Transformers/Attention @DN6 @yiyixuxu @sayakpaul @DN6
|
||||
|
||||
Questions on Schedulers: @yiyixuxu
|
||||
|
||||
@@ -103,7 +99,7 @@ body:
|
||||
|
||||
Questions on JAX- and MPS-related things: @pcuenca
|
||||
|
||||
Questions on audio pipelines: @sanchit-gandhi
|
||||
Questions on audio pipelines: @DN6
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ members/contributors who may be interested in your PR.
|
||||
Core library:
|
||||
|
||||
- Schedulers: @yiyixuxu
|
||||
- Pipelines and pipeline callbacks: @yiyixuxu and @asomoza
|
||||
- Pipelines: @sayakpaul @yiyixuxu @DN6
|
||||
- Training examples: @sayakpaul
|
||||
- Docs: @stevhliu and @sayakpaul
|
||||
- JAX and MPS: @pcuenca
|
||||
@@ -48,8 +48,7 @@ Core library:
|
||||
|
||||
Integrations:
|
||||
|
||||
- deepspeed: HF Trainer/Accelerate: @SunMarc
|
||||
- PEFT: @sayakpaul @BenjaminBossan
|
||||
- deepspeed: HF Trainer/Accelerate: @pacman100
|
||||
|
||||
HF projects:
|
||||
|
||||
|
||||
@@ -13,15 +13,13 @@ env:
|
||||
|
||||
jobs:
|
||||
torch_pipelines_cuda_benchmark_tests:
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
|
||||
name: Torch Core Pipelines CUDA Benchmarking Tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1
|
||||
runs-on: [single-gpu, nvidia-gpu, a10, ci]
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-compile-cuda
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --gpus 0
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
@@ -52,14 +50,4 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: benchmark_test_reports
|
||||
path: benchmarks/benchmark_outputs
|
||||
|
||||
- name: Report success status
|
||||
if: ${{ success() }}
|
||||
run: |
|
||||
pip install requests && python utils/notify_benchmarking_status.py --status=success
|
||||
|
||||
- name: Report failure status
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
pip install requests && python utils/notify_benchmarking_status.py --status=failure
|
||||
path: benchmarks/benchmark_outputs
|
||||
@@ -22,9 +22,6 @@ on:
|
||||
|
||||
jobs:
|
||||
mirror_community_pipeline:
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
# Checkout to correct ref
|
||||
@@ -89,14 +86,4 @@ jobs:
|
||||
run: huggingface-cli upload diffusers/community-pipelines-mirror ./examples/community ${PATH_IN_REPO} --repo-type dataset
|
||||
env:
|
||||
PATH_IN_REPO: ${{ env.PATH_IN_REPO }}
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
|
||||
|
||||
- name: Report success status
|
||||
if: ${{ success() }}
|
||||
run: |
|
||||
pip install requests && python utils/notify_community_pipelines_mirror.py --status=success
|
||||
|
||||
- name: Report failure status
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
|
||||
@@ -330,7 +330,6 @@ jobs:
|
||||
- name: Run example tests on GPU
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||
RUN_COMPILE: yes
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
|
||||
- name: Failure short reports
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
name: SSH into PR runners
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
docker_image:
|
||||
description: 'Name of the Docker image'
|
||||
required: true
|
||||
|
||||
env:
|
||||
IS_GITHUB_CI: "1"
|
||||
HF_HUB_READ_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
|
||||
HF_HOME: /mnt/cache
|
||||
DIFFUSERS_IS_CI: yes
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
RUN_SLOW: yes
|
||||
|
||||
jobs:
|
||||
ssh_runner:
|
||||
name: "SSH"
|
||||
runs-on: [self-hosted, intel-cpu, 32-cpu, 256-ram, ci]
|
||||
container:
|
||||
image: ${{ github.event.inputs.docker_image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface/diffusers:/mnt/cache/ --privileged
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Tailscale # In order to be able to SSH when a test fails
|
||||
uses: huggingface/tailscale-action@main
|
||||
with:
|
||||
authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
|
||||
slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}
|
||||
slackToken: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
waitForSSH: true
|
||||
@@ -1,4 +1,4 @@
|
||||
name: SSH into GPU runners
|
||||
name: SSH into runners
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
+1
-1
@@ -175,4 +175,4 @@ tags
|
||||
.ruff_cache
|
||||
|
||||
# wandb
|
||||
wandb
|
||||
wandb
|
||||
|
||||
+6
-6
@@ -63,14 +63,14 @@ Let's walk through more detailed design decisions for each class.
|
||||
Pipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.
|
||||
|
||||
The following design principles are followed:
|
||||
- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as it’s done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
|
||||
- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as it’s done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [#Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
|
||||
- Pipelines all inherit from [`DiffusionPipeline`].
|
||||
- Every pipeline consists of different model and scheduler components, that are documented in the [`model_index.json` file](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same name as attributes of the pipeline and can be shared between pipelines with [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function.
|
||||
- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function.
|
||||
- Pipelines should be used **only** for inference.
|
||||
- Pipelines should be very readable, self-explanatory, and easy to tweak.
|
||||
- Pipelines should be designed to build on top of each other and be easy to integrate into higher-level APIs.
|
||||
- Pipelines are **not** intended to be feature-complete user interfaces. For feature-complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).
|
||||
- Pipelines are **not** intended to be feature-complete user interfaces. For future complete user interfaces one should rather have a look at [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), and [lama-cleaner](https://github.com/Sanster/lama-cleaner).
|
||||
- Every pipeline should have one and only one way to run it via a `__call__` method. The naming of the `__call__` arguments should be shared across all pipelines.
|
||||
- Pipelines should be named after the task they are intended to solve.
|
||||
- In almost all cases, novel diffusion pipelines shall be implemented in a new pipeline folder/file.
|
||||
@@ -81,7 +81,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
|
||||
|
||||
The following design principles are followed:
|
||||
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
|
||||
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py), [`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py), etc...
|
||||
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
|
||||
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
|
||||
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
|
||||
- Models all inherit from `ModelMixin` and `ConfigMixin`.
|
||||
@@ -90,7 +90,7 @@ The following design principles are followed:
|
||||
- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.
|
||||
- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.
|
||||
- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
|
||||
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
|
||||
### Schedulers
|
||||
|
||||
@@ -100,11 +100,11 @@ The following design principles are followed:
|
||||
- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
|
||||
- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
|
||||
- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
|
||||
- If schedulers share similar functionalities, we can make use of the `# Copied from` mechanism.
|
||||
- If schedulers share similar functionalities, we can make use of the `#Copied from` mechanism.
|
||||
- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.
|
||||
- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](./docs/source/en/using-diffusers/schedulers.md).
|
||||
- Every scheduler has to have a `set_num_inference_steps`, and a `step` function. `set_num_inference_steps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called.
|
||||
- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.
|
||||
- The `step(...)` function takes a predicted model output and the "current" sample (x_t) and returns the "previous", slightly more denoised sample (x_t-1).
|
||||
- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a "black box".
|
||||
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
|
||||
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
|
||||
|
||||
@@ -67,7 +67,7 @@ Please refer to the [How to use Stable Diffusion in Apple Silicon](https://huggi
|
||||
|
||||
## Quickstart
|
||||
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 27.000+ checkpoints):
|
||||
Generating outputs is super easy with 🤗 Diffusers. To generate an image from text, use the `from_pretrained` method to load any pretrained diffusion model (browse the [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) for 25.000+ checkpoints):
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
@@ -209,7 +209,7 @@ Also, say 👋 in our public Discord channel <a href="https://discord.gg/G7tWnz9
|
||||
- https://github.com/deep-floyd/IF
|
||||
- https://github.com/bentoml/BentoML
|
||||
- https://github.com/bmaltais/kohya_ss
|
||||
- +12.000 other amazing GitHub repositories 💪
|
||||
- +11.000 other amazing GitHub repositories 💪
|
||||
|
||||
Thank you for using us ❤️.
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ def main():
|
||||
print(f"****** Running file: {file} ******")
|
||||
|
||||
# Run with canonical settings.
|
||||
if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py":
|
||||
if file != "benchmark_text_to_image.py":
|
||||
command = f"python {file}"
|
||||
run_command(command.split())
|
||||
|
||||
@@ -49,10 +49,6 @@ def main():
|
||||
|
||||
# Run variants.
|
||||
for file in python_files:
|
||||
# See: https://github.com/pytorch/pytorch/issues/129637
|
||||
if file == "benchmark_ip_adapters.py":
|
||||
continue
|
||||
|
||||
if file == "benchmark_text_to_image.py":
|
||||
for ckpt in ALL_T2I_CKPTS:
|
||||
command = f"python {file} --ckpt {ckpt}"
|
||||
|
||||
@@ -17,7 +17,6 @@ RUN apt install -y bash \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
@@ -16,7 +16,6 @@ RUN apt install -y bash \
|
||||
ca-certificates \
|
||||
libsndfile1-dev \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
libgl1 \
|
||||
python3.10-venv && \
|
||||
|
||||
@@ -17,7 +17,6 @@ RUN apt install -y bash \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
@@ -17,7 +17,6 @@ RUN apt install -y bash \
|
||||
libsndfile1-dev \
|
||||
libgl1 \
|
||||
python3.10 \
|
||||
python3.10-dev \
|
||||
python3-pip \
|
||||
python3.10-venv && \
|
||||
rm -rf /var/lib/apt/lists
|
||||
|
||||
@@ -21,8 +21,6 @@
|
||||
title: Load LoRAs for inference
|
||||
- local: tutorials/fast_diffusion
|
||||
title: Accelerate inference of text-to-image diffusion models
|
||||
- local: tutorials/inference_with_big_models
|
||||
title: Working with big models
|
||||
title: Tutorials
|
||||
- sections:
|
||||
- local: using-diffusers/loading
|
||||
@@ -83,8 +81,6 @@
|
||||
title: Kandinsky
|
||||
- local: using-diffusers/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: using-diffusers/pag
|
||||
title: PAG
|
||||
- local: using-diffusers/controlnet
|
||||
title: ControlNet
|
||||
- local: using-diffusers/t2i_adapter
|
||||
@@ -249,12 +245,6 @@
|
||||
title: DiTTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
title: AuraFlowTransformer2DModel
|
||||
- local: api/models/latte_transformer3d
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
@@ -263,8 +253,6 @@
|
||||
title: PriorTransformer
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
- local: api/models/controlnet_hunyuandit
|
||||
title: HunyuanDiT2DControlNetModel
|
||||
- local: api/models/controlnet_sd3
|
||||
title: SD3ControlNetModel
|
||||
title: Models
|
||||
@@ -282,8 +270,6 @@
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/aura_flow
|
||||
title: AuraFlow
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
@@ -292,8 +278,6 @@
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_hunyuandit
|
||||
title: ControlNet with Hunyuan-DiT
|
||||
- local: api/pipelines/controlnet_sd3
|
||||
title: ControlNet with Stable Diffusion 3
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
@@ -326,26 +310,18 @@
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
title: Latent Consistency Models
|
||||
- local: api/pipelines/latent_diffusion
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
title: Marigold
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/musicldm
|
||||
title: MusicLDM
|
||||
- local: api/pipelines/pag
|
||||
title: PAG
|
||||
- local: api/pipelines/paint_by_example
|
||||
title: Paint by Example
|
||||
- local: api/pipelines/pia
|
||||
@@ -449,8 +425,6 @@
|
||||
title: EulerDiscreteScheduler
|
||||
- local: api/schedulers/flow_match_euler_discrete
|
||||
title: FlowMatchEulerDiscreteScheduler
|
||||
- local: api/schedulers/flow_match_heun_discrete
|
||||
title: FlowMatchHeunDiscreteScheduler
|
||||
- local: api/schedulers/heun
|
||||
title: HeunDiscreteScheduler
|
||||
- local: api/schedulers/ipndm
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AuraFlowTransformer2DModel
|
||||
|
||||
A Transformer model for image-like data from [AuraFlow](https://blog.fal.ai/auraflow/).
|
||||
|
||||
## AuraFlowTransformer2DModel
|
||||
|
||||
[[autodoc]] AuraFlowTransformer2DModel
|
||||
@@ -21,7 +21,7 @@ The abstract from the paper is:
|
||||
## Loading from the original format
|
||||
|
||||
By default the [`AutoencoderKL`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
|
||||
from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
|
||||
from the original format using [`FromOriginalVAEMixin.from_single_file`] as follows:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
@@ -21,7 +21,7 @@ The abstract from the paper is:
|
||||
## Loading from the original format
|
||||
|
||||
By default the [`ControlNetModel`] should be loaded with [`~ModelMixin.from_pretrained`], but it can also be loaded
|
||||
from the original format using [`FromOriginalModelMixin.from_single_file`] as follows:
|
||||
from the original format using [`FromOriginalControlnetMixin.from_single_file`] as follows:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# HunyuanDiT2DControlNetModel
|
||||
|
||||
HunyuanDiT2DControlNetModel is an implementation of ControlNet for [Hunyuan-DiT](https://arxiv.org/abs/2405.08748).
|
||||
|
||||
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
|
||||
|
||||
With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||
|
||||
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
|
||||
|
||||
## Example For Loading HunyuanDiT2DControlNetModel
|
||||
|
||||
```py
|
||||
from diffusers import HunyuanDiT2DControlNetModel
|
||||
import torch
|
||||
controlnet = HunyuanDiT2DControlNetModel.from_pretrained("Tencent-Hunyuan/HunyuanDiT-v1.1-ControlNet-Diffusers-Pose", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## HunyuanDiT2DControlNetModel
|
||||
|
||||
[[autodoc]] HunyuanDiT2DControlNetModel
|
||||
@@ -1,19 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
## LatteTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [Latte](https://github.com/Vchitect/Latte).
|
||||
|
||||
## LatteTransformer3DModel
|
||||
|
||||
[[autodoc]] LatteTransformer3DModel
|
||||
@@ -1,20 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LuminaNextDiT2DModel
|
||||
|
||||
A Next Version of Diffusion Transformer model for 2D data from [Lumina-T2X](https://github.com/Alpha-VLLM/Lumina-T2X).
|
||||
|
||||
## LuminaNextDiT2DModel
|
||||
|
||||
[[autodoc]] LuminaNextDiT2DModel
|
||||
|
||||
@@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput
|
||||
|
||||
@@ -560,20 +560,6 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
|
||||
</table>
|
||||
|
||||
|
||||
## Using `from_single_file` with the MotionAdapter
|
||||
|
||||
`diffusers>=0.30.0` supports loading the AnimateDiff checkpoints into the `MotionAdapter` in their original format via `from_single_file`
|
||||
|
||||
```python
|
||||
from diffusers import MotionAdapter
|
||||
|
||||
ckpt_path = "https://huggingface.co/Lightricks/LongAnimateDiff/blob/main/lt_long_mm_32_frames.ckpt"
|
||||
|
||||
adapter = MotionAdapter.from_single_file(ckpt_path, torch_dtype=torch.float16)
|
||||
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter)
|
||||
|
||||
```
|
||||
|
||||
## AnimateDiffPipeline
|
||||
|
||||
[[autodoc]] AnimateDiffPipeline
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AuraFlow
|
||||
|
||||
AuraFlow is inspired by [Stable Diffusion 3](../pipelines/stable_diffusion/stable_diffusion_3.md) and is by far the largest text-to-image generation model that comes with an Apache 2.0 license. This model achieves state-of-the-art results on the [GenEval](https://github.com/djghosh13/geneval) benchmark.
|
||||
|
||||
It was developed by the Fal team and more details about it can be found in [this blog post](https://blog.fal.ai/auraflow/).
|
||||
|
||||
<Tip>
|
||||
|
||||
AuraFlow can be quite expensive to run on consumer hardware devices. However, you can perform a suite of optimizations to run it faster and in a more memory-friendly manner. Check out [this section](https://huggingface.co/blog/sd3#memory-optimizations-for-sd3) for more details.
|
||||
|
||||
</Tip>
|
||||
|
||||
## AuraFlowPipeline
|
||||
|
||||
[[autodoc]] AuraFlowPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,36 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# ControlNet with Hunyuan-DiT
|
||||
|
||||
HunyuanDiTControlNetPipeline is an implementation of ControlNet for [Hunyuan-DiT](https://arxiv.org/abs/2405.08748).
|
||||
|
||||
ControlNet was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang, Anyi Rao, and Maneesh Agrawala.
|
||||
|
||||
With a ControlNet model, you can provide an additional control image to condition and control Hunyuan-DiT generation. For example, if you provide a depth map, the ControlNet model generates an image that'll preserve the spatial information from the depth map. It is a more flexible and accurate way to control the image generation process.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We present ControlNet, a neural network architecture to add spatial conditioning controls to large, pretrained text-to-image diffusion models. ControlNet locks the production-ready large diffusion models, and reuses their deep and robust encoding layers pretrained with billions of images as a strong backbone to learn a diverse set of conditional controls. The neural architecture is connected with "zero convolutions" (zero-initialized convolution layers) that progressively grow the parameters from zero and ensure that no harmful noise could affect the finetuning. We test various conditioning controls, eg, edges, depth, segmentation, human pose, etc, with Stable Diffusion, using single or multiple conditions, with or without prompts. We show that the training of ControlNets is robust with small (<50k) and large (>1m) datasets. Extensive results show that ControlNet may facilitate wider applications to control image diffusion models.*
|
||||
|
||||
This code is implemented by Tencent Hunyuan Team. You can find pre-trained checkpoints for Hunyuan-DiT ControlNets on [Tencent Hunyuan](https://huggingface.co/Tencent-Hunyuan).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
## HunyuanDiTControlNetPipeline
|
||||
[[autodoc]] HunyuanDiTControlNetPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,4 +1,4 @@
|
||||
<!--Copyright 2024 The HuggingFace Team and Tencent Hunyuan Team. All rights reserved.
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
@@ -34,12 +34,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.m
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
You can further improve generation quality by passing the generated image from [`HungyuanDiTPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Optimization
|
||||
|
||||
You can optimize the pipeline's runtime and memory consumption with torch.compile and feed-forward chunking. To learn about other optimization methods, check out the [Speed up inference](../../optimization/fp16) and [Reduce memory usage](../../optimization/memory) guides.
|
||||
|
||||
@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Kandinsky 3 is created by [Vladimir Arkhipkin](https://github.com/oriBetelgeuse),[Anastasia Maltseva](https://github.com/NastyaMittseva),[Igor Pavlov](https://github.com/boomb0om),[Andrei Filatov](https://github.com/anvilarth),[Arseniy Shakhmatov](https://github.com/cene555),[Andrey Kuznetsov](https://github.com/kuznetsoffandrey),[Denis Dimitrov](https://github.com/denndimitrov), [Zein Shaheen](https://github.com/zeinsh)
|
||||
|
||||
The description from it's GitHub page:
|
||||
The description from it's Github page:
|
||||
|
||||
*Kandinsky 3.0 is an open-source text-to-image diffusion model built upon the Kandinsky2-x model family. In comparison to its predecessors, enhancements have been made to the text understanding and visual quality of the model, achieved by increasing the size of the text encoder and Diffusion U-Net models, respectively.*
|
||||
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Kolors: Effective Training of Diffusion Model for Photorealistic Text-to-Image Synthesis
|
||||
|
||||

|
||||
|
||||
Kolors is a large-scale text-to-image generation model based on latent diffusion, developed by [the Kuaishou Kolors team](kwai-kolors@kuaishou.com). Trained on billions of text-image pairs, Kolors exhibits significant advantages over both open-source and closed-source models in visual quality, complex semantic accuracy, and text rendering for both Chinese and English characters. Furthermore, Kolors supports both Chinese and English inputs, demonstrating strong performance in understanding and generating Chinese-specific content. For more details, please refer to this [technical report](https://github.com/Kwai-Kolors/Kolors/blob/master/imgs/Kolors_paper.pdf).
|
||||
|
||||
The abstract from the technical report is:
|
||||
|
||||
*We present Kolors, a latent diffusion model for text-to-image synthesis, characterized by its profound understanding of both English and Chinese, as well as an impressive degree of photorealism. There are three key insights contributing to the development of Kolors. Firstly, unlike large language model T5 used in Imagen and Stable Diffusion 3, Kolors is built upon the General Language Model (GLM), which enhances its comprehension capabilities in both English and Chinese. Moreover, we employ a multimodal large language model to recaption the extensive training dataset for fine-grained text understanding. These strategies significantly improve Kolors’ ability to comprehend intricate semantics, particularly those involving multiple entities, and enable its advanced text rendering capabilities. Secondly, we divide the training of Kolors into two phases: the concept learning phase with broad knowledge and the quality improvement phase with specifically curated high-aesthetic data. Furthermore, we investigate the critical role of the noise schedule and introduce a novel schedule to optimize high-resolution image generation. These strategies collectively enhance the visual appeal of the generated high-resolution images. Lastly, we propose a category-balanced benchmark KolorsPrompts, which serves as a guide for the training and evaluation of Kolors. Consequently, even when employing the commonly used U-Net backbone, Kolors has demonstrated remarkable performance in human evaluations, surpassing the existing open-source models and achieving Midjourney-v6 level performance, especially in terms of visual appeal. We will release the code and weights of Kolors at <https://github.com/Kwai-Kolors/Kolors>, and hope that it will benefit future research and applications in the visual generation community.*
|
||||
|
||||
## Usage Example
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
|
||||
|
||||
pipe = KolorsPipeline.from_pretrained("Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16")
|
||||
pipe.to("cuda")
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
|
||||
|
||||
image = pipe(
|
||||
prompt='一张瓢虫的照片,微距,变焦,高质量,电影,拿着一个牌子,写着"可图"',
|
||||
negative_prompt="",
|
||||
guidance_scale=6.5,
|
||||
num_inference_steps=25,
|
||||
).images[0]
|
||||
|
||||
image.save("kolors_sample.png")
|
||||
```
|
||||
|
||||
## KolorsPipeline
|
||||
|
||||
[[autodoc]] KolorsPipeline
|
||||
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,75 +0,0 @@
|
||||
<!-- # Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# Latte
|
||||
|
||||

|
||||
|
||||
[Latte: Latent Diffusion Transformer for Video Generation](https://arxiv.org/abs/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*We propose a novel Latent Diffusion Transformer, namely Latte, for video generation. Latte first extracts spatio-temporal tokens from input videos and then adopts a series of Transformer blocks to model video distribution in the latent space. In order to model a substantial number of tokens extracted from videos, four efficient variants are introduced from the perspective of decomposing the spatial and temporal dimensions of input videos. To improve the quality of generated videos, we determine the best practices of Latte through rigorous experimental analysis, including video clip patch embedding, model variants, timestep-class information injection, temporal positional embedding, and learning strategies. Our comprehensive evaluation demonstrates that Latte achieves state-of-the-art performance across four standard video generation datasets, i.e., FaceForensics, SkyTimelapse, UCF101, and Taichi-HD. In addition, we extend Latte to text-to-video generation (T2V) task, where Latte achieves comparable results compared to recent T2V models. We strongly believe that Latte provides valuable insights for future research on incorporating Transformers into diffusion models for video generation.*
|
||||
|
||||
**Highlights**: Latte is a latent diffusion transformer proposed as a backbone for modeling different modalities (trained for text-to-video generation here). It achieves state-of-the-art performance across four standard video benchmarks - [FaceForensics](https://arxiv.org/abs/1803.09179), [SkyTimelapse](https://arxiv.org/abs/1709.07592), [UCF101](https://arxiv.org/abs/1212.0402) and [Taichi-HD](https://arxiv.org/abs/2003.00196). To prepare and download the datasets for evaluation, please refer to [this https URL](https://github.com/Vchitect/Latte/blob/main/docs/datasets_evaluation.md).
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Inference
|
||||
|
||||
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
|
||||
|
||||
First, load the pipeline:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LattePipeline
|
||||
|
||||
pipeline = LattePipeline.from_pretrained(
|
||||
"maxin-cn/Latte-1", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
|
||||
|
||||
```python
|
||||
pipeline.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Finally, compile the components and run inference:
|
||||
|
||||
```python
|
||||
pipeline.transformer = torch.compile(pipeline.transformer)
|
||||
pipeline.vae.decode = torch.compile(pipeline.vae.decode)
|
||||
|
||||
video = pipeline(prompt="A dog wearing sunglasses floating in space, surreal, nebulae in background").frames[0]
|
||||
```
|
||||
|
||||
The [benchmark](https://gist.github.com/a-r-r-o-w/4e1694ca46374793c0361d740a99ff19) results on an 80GB A100 machine are:
|
||||
|
||||
```
|
||||
Without torch.compile(): Average inference time: 16.246 seconds.
|
||||
With torch.compile(): Average inference time: 14.573 seconds.
|
||||
```
|
||||
|
||||
## LattePipeline
|
||||
|
||||
[[autodoc]] LattePipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -1,88 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Lumina-T2X
|
||||

|
||||
|
||||
[Lumina-Next : Making Lumina-T2X Stronger and Faster with Next-DiT](https://github.com/Alpha-VLLM/Lumina-T2X/blob/main/assets/lumina-next.pdf) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Lumina-T2X is a nascent family of Flow-based Large Diffusion Transformers (Flag-DiT) that establishes a unified framework for transforming noise into various modalities, such as images and videos, conditioned on text instructions. Despite its promising capabilities, Lumina-T2X still encounters challenges including training instability, slow inference, and extrapolation artifacts. In this paper, we present Lumina-Next, an improved version of Lumina-T2X, showcasing stronger generation performance with increased training and inference efficiency. We begin with a comprehensive analysis of the Flag-DiT architecture and identify several suboptimal components, which we address by introducing the Next-DiT architecture with 3D RoPE and sandwich normalizations. To enable better resolution extrapolation, we thoroughly compare different context extrapolation methods applied to text-to-image generation with 3D RoPE, and propose Frequency- and Time-Aware Scaled RoPE tailored for diffusion transformers. Additionally, we introduce a sigmoid time discretization schedule to reduce sampling steps in solving the Flow ODE and the Context Drop method to merge redundant visual tokens for faster network evaluation, effectively boosting the overall sampling speed. Thanks to these improvements, Lumina-Next not only improves the quality and efficiency of basic text-to-image generation but also demonstrates superior resolution extrapolation capabilities and multilingual generation using decoder-based LLMs as the text encoder, all in a zero-shot manner. To further validate Lumina-Next as a versatile generative framework, we instantiate it on diverse tasks including visual recognition, multi-view, audio, music, and point cloud generation, showcasing strong performance across these domains. By releasing all codes and model weights at https://github.com/Alpha-VLLM/Lumina-T2X, we aim to advance the development of next-generation generative AI capable of universal modeling.*
|
||||
|
||||
**Highlights**: Lumina-Next is a next-generation Diffusion Transformer that significantly enhances text-to-image generation, multilingual generation, and multitask performance by introducing the Next-DiT architecture, 3D RoPE, and frequency- and time-aware RoPE, among other improvements.
|
||||
|
||||
Lumina-Next has the following components:
|
||||
* It improves sampling efficiency with fewer and faster Steps.
|
||||
* It uses a Next-DiT as a transformer backbone with Sandwichnorm 3D RoPE, and Grouped-Query Attention.
|
||||
* It uses a Frequency- and Time-Aware Scaled RoPE.
|
||||
|
||||
---
|
||||
|
||||
[Lumina-T2X: Transforming Text into Any Modality, Resolution, and Duration via Flow-based Large Diffusion Transformers](https://arxiv.org/abs/2405.05945) from Alpha-VLLM, OpenGVLab, Shanghai AI Laboratory.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Sora unveils the potential of scaling Diffusion Transformer for generating photorealistic images and videos at arbitrary resolutions, aspect ratios, and durations, yet it still lacks sufficient implementation details. In this technical report, we introduce the Lumina-T2X family - a series of Flow-based Large Diffusion Transformers (Flag-DiT) equipped with zero-initialized attention, as a unified framework designed to transform noise into images, videos, multi-view 3D objects, and audio clips conditioned on text instructions. By tokenizing the latent spatial-temporal space and incorporating learnable placeholders such as [nextline] and [nextframe] tokens, Lumina-T2X seamlessly unifies the representations of different modalities across various spatial-temporal resolutions. This unified approach enables training within a single framework for different modalities and allows for flexible generation of multimodal data at any resolution, aspect ratio, and length during inference. Advanced techniques like RoPE, RMSNorm, and flow matching enhance the stability, flexibility, and scalability of Flag-DiT, enabling models of Lumina-T2X to scale up to 7 billion parameters and extend the context window to 128K tokens. This is particularly beneficial for creating ultra-high-definition images with our Lumina-T2I model and long 720p videos with our Lumina-T2V model. Remarkably, Lumina-T2I, powered by a 5-billion-parameter Flag-DiT, requires only 35% of the training computational costs of a 600-million-parameter naive DiT. Our further comprehensive analysis underscores Lumina-T2X's preliminary capability in resolution extrapolation, high-resolution editing, generating consistent 3D views, and synthesizing videos with seamless transitions. We expect that the open-sourcing of Lumina-T2X will further foster creativity, transparency, and diversity in the generative AI community.*
|
||||
|
||||
|
||||
You can find the original codebase at [Alpha-VLLM](https://github.com/Alpha-VLLM/Lumina-T2X) and all the available checkpoints at [Alpha-VLLM Lumina Family](https://huggingface.co/collections/Alpha-VLLM/lumina-family-66423205bedb81171fd0644b).
|
||||
|
||||
**Highlights**: Lumina-T2X supports Any Modality, Resolution, and Duration.
|
||||
|
||||
Lumina-T2X has the following components:
|
||||
* It uses a Flow-based Large Diffusion Transformer as the backbone
|
||||
* It supports different any modalities with one backbone and corresponding encoder, decoder.
|
||||
|
||||
<Tip>
|
||||
|
||||
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
</Tip>
|
||||
|
||||
### Inference (Text-to-Image)
|
||||
|
||||
Use [`torch.compile`](https://huggingface.co/docs/diffusers/main/en/tutorials/fast_diffusion#torchcompile) to reduce the inference latency.
|
||||
|
||||
First, load the pipeline:
|
||||
|
||||
```python
|
||||
from diffusers import LuminaText2ImgPipeline
|
||||
import torch
|
||||
|
||||
pipeline = LuminaText2ImgPipeline.from_pretrained(
|
||||
"Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
Then change the memory layout of the pipelines `transformer` and `vae` components to `torch.channels-last`:
|
||||
|
||||
```python
|
||||
pipeline.transformer.to(memory_format=torch.channels_last)
|
||||
pipeline.vae.to(memory_format=torch.channels_last)
|
||||
```
|
||||
|
||||
Finally, compile the components and run inference:
|
||||
|
||||
```python
|
||||
pipeline.transformer = torch.compile(pipeline.transformer, mode="max-autotune", fullgraph=True)
|
||||
pipeline.vae.decode = torch.compile(pipeline.vae.decode, mode="max-autotune", fullgraph=True)
|
||||
|
||||
image = pipeline(prompt="Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures").images[0]
|
||||
```
|
||||
|
||||
## LuminaText2ImgPipeline
|
||||
|
||||
[[autodoc]] LuminaText2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules.
|
||||
|
||||
PAG was introduced in [Self-Rectifying Diffusion Sampling with Perturbed-Attention Guidance](https://huggingface.co/papers/2403.17377) by Donghoon Ahn, Hyoungwon Cho, Jaewon Min, Wooseok Jang, Jungwoo Kim, SeonHwa Kim, Hyun Hee Park, Kyong Hwan Jin and Seungryong Kim.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
*Recent studies have demonstrated that diffusion models are capable of generating high-quality samples, but their quality heavily depends on sampling guidance techniques, such as classifier guidance (CG) and classifier-free guidance (CFG). These techniques are often not applicable in unconditional generation or in various downstream tasks such as image restoration. In this paper, we propose a novel sampling guidance, called Perturbed-Attention Guidance (PAG), which improves diffusion sample quality across both unconditional and conditional settings, achieving this without requiring additional training or the integration of external modules. PAG is designed to progressively enhance the structure of samples throughout the denoising process. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, by considering the self-attention mechanisms' ability to capture structural information, and guiding the denoising process away from these degraded samples. In both ADM and Stable Diffusion, PAG surprisingly improves sample quality in conditional and even unconditional scenarios. Moreover, PAG significantly improves the baseline performance in various downstream tasks where existing guidances such as CG or CFG cannot be fully utilized, including ControlNet with empty prompts and image restoration such as inpainting and deblurring.*
|
||||
|
||||
## StableDiffusionPAGPipeline
|
||||
[[autodoc]] StableDiffusionPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionControlNetPAGPipeline
|
||||
[[autodoc]] StableDiffusionControlNetPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLPAGPipeline
|
||||
[[autodoc]] StableDiffusionXLPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLPAGImg2ImgPipeline
|
||||
[[autodoc]] StableDiffusionXLPAGImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLPAGInpaintPipeline
|
||||
[[autodoc]] StableDiffusionXLPAGInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## StableDiffusionXLControlNetPAGPipeline
|
||||
[[autodoc]] StableDiffusionXLControlNetPAGPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -37,12 +37,6 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
You can further improve generation quality by passing the generated image from [`PixArtSigmaPipeline`] to the [SDXL refiner](../../using-diffusers/sdxl#base-to-refiner-model) model.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Inference with under 8GB GPU VRAM
|
||||
|
||||
Run the [`PixArtSigmaPipeline`] with under 8GB GPU VRAM by loading the text encoder in 8-bit precision. Let's walk through a full-fledged example.
|
||||
|
||||
@@ -48,7 +48,7 @@ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
|
||||
import torch
|
||||
|
||||
repo_id = "stabilityai/stable-diffusion-2-base"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, variant="fp16")
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
|
||||
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
@@ -72,7 +72,7 @@ init_image = load_image(img_url).resize((512, 512))
|
||||
mask_image = load_image(mask_url).resize((512, 512))
|
||||
|
||||
repo_id = "stabilityai/stable-diffusion-2-inpainting"
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, variant="fp16")
|
||||
pipe = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.float16, revision="fp16")
|
||||
|
||||
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
@@ -35,6 +35,7 @@ The SD3 pipeline uses three text encoders to generate an image. Model offloading
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
@@ -196,47 +197,6 @@ image.save("sd3_hello_world.png")
|
||||
|
||||
Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).
|
||||
|
||||
## Using Long Prompts with the T5 Text Encoder
|
||||
|
||||
By default, the T5 Text Encoder prompt uses a maximum sequence length of `256`. This can be adjusted by setting the `max_sequence_length` to accept fewer or more tokens. Keep in mind that longer sequences require additional resources and result in longer generation times, such as during batch inference.
|
||||
|
||||
```python
|
||||
prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creature’s body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree. As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight"
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt="",
|
||||
num_inference_steps=28,
|
||||
guidance_scale=4.5,
|
||||
max_sequence_length=512,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
### Sending a different prompt to the T5 Text Encoder
|
||||
|
||||
You can send a different prompt to the CLIP Text Encoders and the T5 Text Encoder to prevent the prompt from being truncated by the CLIP Text Encoders and to improve generation.
|
||||
|
||||
<Tip>
|
||||
|
||||
The prompt with the CLIP Text Encoders is still truncated to the 77 token limit.
|
||||
|
||||
</Tip>
|
||||
|
||||
```python
|
||||
prompt = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. A river of warm, melted butter, pancake-like foliage in the background, a towering pepper mill standing in for a tree."
|
||||
|
||||
prompt_3 = "A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus, basking in a river of melted butter amidst a breakfast-themed landscape. It features the distinctive, bulky body shape of a hippo. However, instead of the usual grey skin, the creature’s body resembles a golden-brown, crispy waffle fresh off the griddle. The skin is textured with the familiar grid pattern of a waffle, each square filled with a glistening sheen of syrup. The environment combines the natural habitat of a hippo with elements of a breakfast table setting, a river of warm, melted butter, with oversized utensils or plates peeking out from the lush, pancake-like foliage in the background, a towering pepper mill standing in for a tree. As the sun rises in this fantastical world, it casts a warm, buttery glow over the scene. The creature, content in its butter river, lets out a yawn. Nearby, a flock of birds take flight"
|
||||
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
prompt_3=prompt_3,
|
||||
negative_prompt="",
|
||||
num_inference_steps=28,
|
||||
guidance_scale=4.5,
|
||||
max_sequence_length=512,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Tiny AutoEncoder for Stable Diffusion 3
|
||||
|
||||
Tiny AutoEncoder for Stable Diffusion (TAESD3) is a tiny distilled version of Stable Diffusion 3's VAE by [Ollin Boer Bohan](https://github.com/madebyollin/taesd) that can decode [`StableDiffusion3Pipeline`] latents almost instantly.
|
||||
@@ -291,9 +251,6 @@ image.save('sd3-single-file.png')
|
||||
|
||||
### Loading the single file checkpoint with T5
|
||||
|
||||
> [!TIP]
|
||||
> The following example loads a checkpoint stored in a 8-bit floating point format which requires PyTorch 2.3 or later.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# FlowMatchHeunDiscreteScheduler
|
||||
|
||||
`FlowMatchHeunDiscreteScheduler` is based on the flow-matching sampling introduced in [EDM](https://arxiv.org/abs/2403.03206).
|
||||
|
||||
## FlowMatchHeunDiscreteScheduler
|
||||
[[autodoc]] FlowMatchHeunDiscreteScheduler
|
||||
@@ -63,7 +63,7 @@ Let's walk through more in-detail design decisions for each class.
|
||||
Pipelines are designed to be easy to use (therefore do not follow [*Simple over easy*](#simple-over-easy) 100%), are not feature complete, and should loosely be seen as examples of how to use [models](#models) and [schedulers](#schedulers) for inference.
|
||||
|
||||
The following design principles are followed:
|
||||
- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as it’s done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
|
||||
- Pipelines follow the single-file policy. All pipelines can be found in individual directories under src/diffusers/pipelines. One pipeline folder corresponds to one diffusion paper/project/release. Multiple pipeline files can be gathered in one pipeline folder, as it’s done for [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion). If pipelines share similar functionality, one can make use of the [#Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251).
|
||||
- Pipelines all inherit from [`DiffusionPipeline`].
|
||||
- Every pipeline consists of different model and scheduler components, that are documented in the [`model_index.json` file](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json), are accessible under the same name as attributes of the pipeline and can be shared between pipelines with [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) function.
|
||||
- Every pipeline should be loadable via the [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) function.
|
||||
@@ -81,7 +81,7 @@ Models are designed as configurable toolboxes that are natural extensions of [Py
|
||||
|
||||
The following design principles are followed:
|
||||
- Models correspond to **a type of model architecture**. *E.g.* the [`UNet2DConditionModel`] class is used for all UNet variations that expect 2D image inputs and are conditioned on some context.
|
||||
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unets/unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_condition.py), [`transformers/transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_2d.py), etc...
|
||||
- All models can be found in [`src/diffusers/models`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models) and every model architecture shall be defined in its file, e.g. [`unet_2d_condition.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py), [`transformer_2d.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py), etc...
|
||||
- Models **do not** follow the single-file policy and should make use of smaller model building blocks, such as [`attention.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py), [`resnet.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py), [`embeddings.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py), etc... **Note**: This is in stark contrast to Transformers' modeling files and shows that models do not really follow the single-file policy.
|
||||
- Models intend to expose complexity, just like PyTorch's `Module` class, and give clear error messages.
|
||||
- Models all inherit from `ModelMixin` and `ConfigMixin`.
|
||||
@@ -90,7 +90,7 @@ The following design principles are followed:
|
||||
- To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different.
|
||||
- Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work.
|
||||
- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and
|
||||
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unets/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
readable long-term, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
|
||||
### Schedulers
|
||||
|
||||
@@ -100,11 +100,11 @@ The following design principles are followed:
|
||||
- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers).
|
||||
- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained.
|
||||
- One scheduler Python file corresponds to one scheduler algorithm (as might be defined in a paper).
|
||||
- If schedulers share similar functionalities, we can make use of the `# Copied from` mechanism.
|
||||
- If schedulers share similar functionalities, we can make use of the `#Copied from` mechanism.
|
||||
- Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`.
|
||||
- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](../using-diffusers/schedulers).
|
||||
- Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](../using-diffusers/schedulers.md).
|
||||
- Every scheduler has to have a `set_num_inference_steps`, and a `step` function. `set_num_inference_steps(...)` has to be called before every denoising process, *i.e.* before `step(...)` is called.
|
||||
- Every scheduler exposes the timesteps to be "looped over" via a `timesteps` attribute, which is an array of timesteps the model will be called upon.
|
||||
- The `step(...)` function takes a predicted model output and the "current" sample (x_t) and returns the "previous", slightly more denoised sample (x_t-1).
|
||||
- Given the complexity of diffusion schedulers, the `step` function does not expose all the complexity and can be a bit of a "black box".
|
||||
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
|
||||
- In almost all cases, novel schedulers shall be implemented in a new scheduling file.
|
||||
|
||||
@@ -52,6 +52,76 @@ To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](h
|
||||
|
||||
</Tip>
|
||||
|
||||
### Device placement
|
||||
|
||||
> [!WARNING]
|
||||
> This feature is experimental and its APIs might change in the future.
|
||||
|
||||
With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
|
||||
|
||||
For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
|
||||
|
||||
* it only works on a single GPU
|
||||
* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
|
||||
|
||||
To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
|
||||
|
||||
> [!WARNING]
|
||||
> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
|
||||
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
|
||||
)
|
||||
image = pipeline("a dog").images[0]
|
||||
image
|
||||
```
|
||||
|
||||
You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
max_memory = {0:"1GB", 1:"1GB"}
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
device_map="balanced",
|
||||
+ max_memory=max_memory
|
||||
)
|
||||
image = pipeline("a dog").images[0]
|
||||
image
|
||||
```
|
||||
|
||||
If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
|
||||
|
||||
By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you could defer to using [`~DiffusionPipeline.enable_sequential_cpu_offload`] and [`~DiffusionPipeline.enable_model_cpu_offload`].
|
||||
|
||||
Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
|
||||
|
||||
```py
|
||||
pipeline.reset_device_map()
|
||||
```
|
||||
|
||||
Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
|
||||
|
||||
```py
|
||||
print(pipeline.hf_device_map)
|
||||
```
|
||||
|
||||
An example device map would look like so:
|
||||
|
||||
|
||||
```bash
|
||||
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
|
||||
```
|
||||
|
||||
## PyTorch Distributed
|
||||
|
||||
PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism.
|
||||
@@ -106,6 +176,3 @@ Once you've completed the inference script, use the `--nproc_per_node` argument
|
||||
```bash
|
||||
torchrun run_distributed.py --nproc_per_node=2
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can use `device_map` within a [`DiffusionPipeline`] to distribute its model-level components on multiple devices. Refer to the [Device placement](../tutorials/inference_with_big_models#device-placement) guide to learn more.
|
||||
@@ -1,139 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Working with big models
|
||||
|
||||
A modern diffusion model, like [Stable Diffusion XL (SDXL)](../using-diffusers/sdxl), is not just a single model, but a collection of multiple models. SDXL has four different model-level components:
|
||||
|
||||
* A variational autoencoder (VAE)
|
||||
* Two text encoders
|
||||
* A UNet for denoising
|
||||
|
||||
Usually, the text encoders and the denoiser are much larger compared to the VAE.
|
||||
|
||||
As models get bigger and better, it’s possible your model is so big that even a single copy won’t fit in memory. But that doesn’t mean it can’t be loaded. If you have more than one GPU, there is more memory available to store your model. In this case, it’s better to split your model checkpoint into several smaller *checkpoint shards*.
|
||||
|
||||
When a text encoder checkpoint has multiple shards, like [T5-xxl for SD3](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/tree/main/text_encoder_3), it is automatically handled by the [Transformers](https://huggingface.co/docs/transformers/index) library as it is a required dependency of Diffusers when using the [`StableDiffusion3Pipeline`]. More specifically, Transformers will automatically handle the loading of multiple shards within the requested model class and get it ready so that inference can be performed.
|
||||
|
||||
The denoiser checkpoint can also have multiple shards and supports inference thanks to the [Accelerate](https://huggingface.co/docs/accelerate/index) library.
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the [Handling big models for inference](https://huggingface.co/docs/accelerate/main/en/concept_guides/big_model_inference) guide for general guidance when working with big models that are hard to fit into memory.
|
||||
|
||||
For example, let's save a sharded checkpoint for the [SDXL UNet](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/tree/main/unet):
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
|
||||
)
|
||||
unet.save_pretrained("sdxl-unet-sharded", max_shard_size="5GB")
|
||||
```
|
||||
|
||||
The size of the fp32 variant of the SDXL UNet checkpoint is ~10.4GB. Set the `max_shard_size` parameter to 5GB to create 3 shards. After saving, you can load them in [`StableDiffusionXLPipeline`]:
|
||||
|
||||
```python
|
||||
from diffusers import UNet2DConditionModel, StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
"sayakpaul/sdxl-unet-sharded", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
image = pipeline("a cute dog running on the grass", num_inference_steps=30).images[0]
|
||||
image.save("dog.png")
|
||||
```
|
||||
|
||||
If placing all the model-level components on the GPU at once is not feasible, use [`~DiffusionPipeline.enable_model_cpu_offload`] to help you:
|
||||
|
||||
```diff
|
||||
- pipeline.to("cuda")
|
||||
+ pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
In general, we recommend sharding when a checkpoint is more than 5GB (in fp32).
|
||||
|
||||
## Device placement
|
||||
|
||||
On distributed setups, you can run inference across multiple GPUs with Accelerate.
|
||||
|
||||
> [!WARNING]
|
||||
> This feature is experimental and its APIs might change in the future.
|
||||
|
||||
With Accelerate, you can use the `device_map` to determine how to distribute the models of a pipeline across multiple devices. This is useful in situations where you have more than one GPU.
|
||||
|
||||
For example, if you have two 8GB GPUs, then using [`~DiffusionPipeline.enable_model_cpu_offload`] may not work so well because:
|
||||
|
||||
* it only works on a single GPU
|
||||
* a single model might not fit on a single GPU ([`~DiffusionPipeline.enable_sequential_cpu_offload`] might work but it will be extremely slow and it is also limited to a single GPU)
|
||||
|
||||
To make use of both GPUs, you can use the "balanced" device placement strategy which splits the models across all available GPUs.
|
||||
|
||||
> [!WARNING]
|
||||
> Only the "balanced" strategy is supported at the moment, and we plan to support additional mapping strategies in the future.
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True,
|
||||
+ "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True, device_map="balanced"
|
||||
)
|
||||
image = pipeline("a dog").images[0]
|
||||
image
|
||||
```
|
||||
|
||||
You can also pass a dictionary to enforce the maximum GPU memory that can be used on each device:
|
||||
|
||||
```diff
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
max_memory = {0:"1GB", 1:"1GB"}
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
device_map="balanced",
|
||||
+ max_memory=max_memory
|
||||
)
|
||||
image = pipeline("a dog").images[0]
|
||||
image
|
||||
```
|
||||
|
||||
If a device is not present in `max_memory`, then it will be completely ignored and will not participate in the device placement.
|
||||
|
||||
By default, Diffusers uses the maximum memory of all devices. If the models don't fit on the GPUs, they are offloaded to the CPU. If the CPU doesn't have enough memory, then you might see an error. In that case, you could defer to using [`~DiffusionPipeline.enable_sequential_cpu_offload`] and [`~DiffusionPipeline.enable_model_cpu_offload`].
|
||||
|
||||
Call [`~DiffusionPipeline.reset_device_map`] to reset the `device_map` of a pipeline. This is also necessary if you want to use methods like `to()`, [`~DiffusionPipeline.enable_sequential_cpu_offload`], and [`~DiffusionPipeline.enable_model_cpu_offload`] on a pipeline that was device-mapped.
|
||||
|
||||
```py
|
||||
pipeline.reset_device_map()
|
||||
```
|
||||
|
||||
Once a pipeline has been device-mapped, you can also access its device map via `hf_device_map`:
|
||||
|
||||
```py
|
||||
print(pipeline.hf_device_map)
|
||||
```
|
||||
|
||||
An example device map would look like so:
|
||||
|
||||
|
||||
```bash
|
||||
{'unet': 1, 'vae': 1, 'safety_checker': 0, 'text_encoder': 0}
|
||||
```
|
||||
@@ -418,7 +418,7 @@ my_local_checkpoint_path = hf_hub_download(
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
)
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_single_file(my_local_checkpoint_path, config=my_local_config_path, local_files_only=True)
|
||||
@@ -438,7 +438,7 @@ my_local_checkpoint_path = hf_hub_download(
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
local_dir="my_local_config"
|
||||
)
|
||||
|
||||
@@ -468,7 +468,7 @@ print("My local checkpoint: ", my_local_checkpoint_path)
|
||||
|
||||
my_local_config_path = snapshot_download(
|
||||
repo_id="segmind/SSD-1B",
|
||||
allow_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
allowed_patterns=["*.json", "**/*.json", "*.txt", "**/*.txt"]
|
||||
local_dir_use_symlinks=False,
|
||||
)
|
||||
print("My local config: ", my_local_config_path)
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Perturbed-Attention Guidance
|
||||
|
||||
[Perturbed-Attention Guidance (PAG)](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) is a new diffusion sampling guidance that improves sample quality across both unconditional and conditional settings, achieving this without requiring further training or the integration of external modules. PAG is designed to progressively enhance the structure of synthesized samples throughout the denoising process by considering the self-attention mechanisms' ability to capture structural information. It involves generating intermediate samples with degraded structure by substituting selected self-attention maps in diffusion U-Net with an identity matrix, and guiding the denoising process away from these degraded samples.
|
||||
|
||||
This guide will show you how to use PAG for various tasks and use cases.
|
||||
|
||||
|
||||
## General tasks
|
||||
|
||||
You can apply PAG to the [`StableDiffusionXLPipeline`] for tasks such as text-to-image, image-to-image, and inpainting. To enable PAG for a specific task, load the pipeline using the [AutoPipeline](../api/pipelines/auto_pipeline) API with the `enable_pag=True` flag and the `pag_applied_layers` argument.
|
||||
|
||||
> [!TIP]
|
||||
> 🤗 Diffusers currently only supports using PAG with selected SDXL pipelines, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you want to add PAG support to a new pipeline!
|
||||
|
||||
<hfoptions id="tasks">
|
||||
<hfoption id="Text-to-image">
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
enable_pag=True,
|
||||
pag_applied_layers=["mid"],
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> The `pag_applied_layers` argument allows you to specify which layers PAG is applied to. Additionally, you can use `set_pag_applied_layers` method to update these layers after the pipeline has been created. Check out the [pag_applied_layers](#pag_applied_layers) section to learn more about applying PAG to other layers.
|
||||
|
||||
If you already have a pipeline created and loaded, you can enable PAG on it using the `from_pipe` API with the `enable_pag` flag. Internally, a PAG pipeline is created based on the pipeline and task you specified. In the example below, since we used `AutoPipelineForText2Image` and passed a `StableDiffusionXLPipeline`, a `StableDiffusionXLPAGPipeline` is created accordingly. Note that this does not require additional memory, and you will have both `StableDiffusionXLPipeline` and `StableDiffusionXLPAGPipeline` loaded and ready to use. You can read more about the `from_pipe` API and how to reuse pipelines in diffuser[here](https://huggingface.co/docs/diffusers/using-diffusers/loading#reuse-a-pipeline)
|
||||
|
||||
```py
|
||||
pipeline_sdxl = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0, torch_dtype=torch.float16")
|
||||
pipeline = AutoPipelineForText2Image.from_pipe(pipeline_sdxl, enable_pag=True)
|
||||
```
|
||||
|
||||
To generate an image, you will also need to pass a `pag_scale`. When `pag_scale` increases, images gain more semantically coherent structures and exhibit fewer artifacts. However overly large guidance scale can lead to smoother textures and slight saturation in the images, similarly to CFG. `pag_scale=3.0` is used in the official demo and works well in most of the use cases, but feel free to experiment and select the appropriate value according to your needs! PAG is disabled when `pag_scale=0`.
|
||||
|
||||
```py
|
||||
prompt = "an insect robot preparing a delicious meal, anime style"
|
||||
|
||||
for pag_scale in [0.0, 3.0]:
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=25,
|
||||
guidance_scale=7.0,
|
||||
generator=generator,
|
||||
pag_scale=pag_scale,
|
||||
).images
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_cfg_7.0_mid.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image without PAG</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_mid.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image with PAG</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Image-to-image">
|
||||
|
||||
You can use PAG with image-to-image pipelines.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForImage2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
enable_pag=True,
|
||||
pag_applied_layers=["mid"],
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
If you already have a image-to-image pipeline and would like enable PAG on it, you can run this
|
||||
|
||||
```py
|
||||
pipeline_t2i = AutoPipelineForImage2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_t2i, enable_pag=True)
|
||||
```
|
||||
|
||||
It is also very easy to directly switch from a text-to-image pipeline to PAG enabled image-to-image pipeline
|
||||
|
||||
```py
|
||||
pipeline_pag = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_t2i, enable_pag=True)
|
||||
```
|
||||
|
||||
If you have a PAG enabled text-to-image pipeline, you can directly switch to a image-to-image pipeline with PAG still enabled
|
||||
|
||||
```py
|
||||
pipeline_pag = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", enable_pag=True, torch_dtype=torch.float16)
|
||||
pipeline = AutoPipelineForImage2Image.from_pipe(pipeline_t2i)
|
||||
```
|
||||
|
||||
Now let's generate an image!
|
||||
|
||||
```py
|
||||
pag_scales = 4.0
|
||||
guidance_scales = 7.0
|
||||
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
|
||||
init_image = load_image(url)
|
||||
prompt = "a dog catching a frisbee in the jungle"
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
image = pipeline(
|
||||
prompt,
|
||||
image=init_image,
|
||||
strength=0.8,
|
||||
guidance_scale=guidance_scale,
|
||||
pag_scale=pag_scale,
|
||||
generator=generator).images[0]
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Inpainting">
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForInpainting
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForInpainting.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
enable_pag=True,
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
You can enable PAG on an exisiting inpainting pipeline like this
|
||||
|
||||
```py
|
||||
pipeline_inpaint = AutoPipelineForInpaiting.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipeline = AutoPipelineForInpaiting.from_pipe(pipeline_inpaint, enable_pag=True)
|
||||
```
|
||||
|
||||
This still works when your pipeline has a different task:
|
||||
|
||||
```py
|
||||
pipeline_t2i = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
|
||||
pipeline = AutoPipelineForInpaiting.from_pipe(pipeline_t2i, enable_pag=True)
|
||||
```
|
||||
|
||||
Let's generate an image!
|
||||
|
||||
```py
|
||||
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
|
||||
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
init_image = load_image(img_url).convert("RGB")
|
||||
mask_image = load_image(mask_url).convert("RGB")
|
||||
|
||||
prompt = "A majestic tiger sitting on a bench"
|
||||
|
||||
pag_scales = 3.0
|
||||
guidance_scales = 7.5
|
||||
|
||||
generator = torch.Generator(device="cpu").manual_seed(1)
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
strength=0.8,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=generator,
|
||||
pag_scale=pag_scale,
|
||||
).images
|
||||
images[0]
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## PAG with ControlNet
|
||||
|
||||
To use PAG with ControlNet, first create a `controlnet`. Then, pass the `controlnet` and other PAG arguments to the `from_pretrained` method of the AutoPipeline for the specified task.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, ControlNetModel
|
||||
import torch
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
"diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
controlnet=controlnet,
|
||||
enable_pag=True,
|
||||
pag_applied_layers="mid",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
If you already have a controlnet pipeline and want to enable PAG, you can use the `from_pipe` API: `AutoPipelineForText2Image.from_pipe(pipeline_controlnet, enable_pag=True)`
|
||||
|
||||
</Tip>
|
||||
|
||||
You can use the pipeline in the same way you normally use ControlNet pipelines, with the added option to specify a `pag_scale` parameter. Note that PAG works well for unconditional generation. In this example, we will generate an image without a prompt.
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
canny_image = load_image(
|
||||
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_control_input.png"
|
||||
)
|
||||
|
||||
for pag_scale in [0.0, 3.0]:
|
||||
generator = torch.Generator(device="cpu").manual_seed(1)
|
||||
images = pipeline(
|
||||
prompt="",
|
||||
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
||||
image=canny_image,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=0,
|
||||
generator=generator,
|
||||
pag_scale=pag_scale,
|
||||
).images
|
||||
images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_controlnet.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image without PAG</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_controlnet.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image with PAG</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## PAG with IP-Adapter
|
||||
|
||||
[IP-Adapter](https://hf.co/papers/2308.06721) is a popular model that can be plugged into diffusion models to enable image prompting without any changes to the underlying model. You can enable PAG on a pipeline with IP-Adapter loaded.
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.utils import load_image
|
||||
from transformers import CLIPVisionModelWithProjection
|
||||
import torch
|
||||
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
"h94/IP-Adapter",
|
||||
subfolder="models/image_encoder",
|
||||
torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
image_encoder=image_encoder,
|
||||
enable_pag=True,
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.bin")
|
||||
|
||||
pag_scales = 5.0
|
||||
ip_adapter_scales = 0.8
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")
|
||||
|
||||
pipeline.set_ip_adapter_scale(ip_adapter_scale)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
images = pipeline(
|
||||
prompt="a polar bear sitting in a chair drinking a milkshake",
|
||||
ip_adapter_image=image,
|
||||
negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
|
||||
num_inference_steps=25,
|
||||
guidance_scale=3.0,
|
||||
generator=generator,
|
||||
pag_scale=pag_scale,
|
||||
).images
|
||||
images[0]
|
||||
|
||||
```
|
||||
|
||||
PAG reduces artifacts and improves the overall compposition.
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_0.0_ipa_0.8.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image without PAG</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_5.0_ipa_0.8.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image with PAG</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
## Configure parameters
|
||||
|
||||
### pag_applied_layers
|
||||
|
||||
The `pag_applied_layers` argument allows you to specify which layers PAG is applied to. By default, it applies only to the mid blocks. Changing this setting will significantly impact the output. You can use the `set_pag_applied_layers` method to adjust the PAG layers after the pipeline is created, helping you find the optimal layers for your model.
|
||||
|
||||
As an example, here is the images generated with `pag_layers = ["down.block_2"]` and `pag_layers = ["down.block_2", "up.block_1.attentions_0"]`
|
||||
|
||||
```py
|
||||
prompt = "an insect robot preparing a delicious meal, anime style"
|
||||
pipeline.set_pag_applied_layers(pag_layers)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
images = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=25,
|
||||
guidance_scale=guidance_scale,
|
||||
generator=generator,
|
||||
pag_scale=pag_scale,
|
||||
).images
|
||||
images[0]
|
||||
```
|
||||
|
||||
<div class="flex flex-row gap-4">
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_down2_up1a0.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">down.block_2 + up.block1.attentions_0</figcaption>
|
||||
</div>
|
||||
<div class="flex-1">
|
||||
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/pag_3.0_cfg_7.0_down2.png"/>
|
||||
<figcaption class="mt-2 text-center text-sm text-gray-500">down.block_2</figcaption>
|
||||
</div>
|
||||
</div>
|
||||
@@ -186,7 +186,7 @@ scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
"runwayml/stable-diffusion-v1-5",
|
||||
scheduler=scheduler,
|
||||
variant="bf16",
|
||||
revision="bf16",
|
||||
dtype=jax.numpy.bfloat16,
|
||||
)
|
||||
params["scheduler"] = scheduler_state
|
||||
|
||||
@@ -285,12 +285,6 @@ refiner = DiffusionPipeline.from_pretrained(
|
||||
).to("cuda")
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
You can use SDXL refiner with a different base model. For example, you can use the [Hunyuan-DiT](../../api/pipelines/hunyuandit) or [PixArt-Sigma](../../api/pipelines/pixart_sigma) pipelines to generate images with better prompt adherence. Once you have generated an image, you can pass it to the SDXL refiner model to enhance final generation quality.
|
||||
|
||||
</Tip>
|
||||
|
||||
Generate an image from the base model, and set the model output to **latent** space:
|
||||
|
||||
```py
|
||||
|
||||
@@ -63,7 +63,7 @@ Flax is a functional framework, so models are stateless and parameters are store
|
||||
dtype = jnp.bfloat16
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
variant="bf16",
|
||||
revision="bf16",
|
||||
dtype=dtype,
|
||||
)
|
||||
```
|
||||
|
||||
@@ -10,30 +10,30 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# 철학 [[philosophy]]
|
||||
# 철학
|
||||
|
||||
🧨 Diffusers는 다양한 모달리티에서 **최신의** 사전 훈련된 diffusion 모델을 제공합니다.
|
||||
그 목적은 추론과 훈련을 위한 **모듈식 툴박스**로 사용되는 것입니다.
|
||||
|
||||
저희는 시간이 지나도 변치 않는 라이브러리를 구축하는 것을 목표로 하기에 API 설계를 매우 중요하게 생각합니다.
|
||||
우리는 오랜 시간에 견딜 수 있는 라이브러리를 구축하는 것을 목표로 하고, 따라서 API 설계를 매우 중요시합니다.
|
||||
|
||||
간단히 말해서, Diffusers는 PyTorch를 자연스럽게 확장할 수 있도록 만들어졌습니다. 따라서 대부분의 설계 선택은 [PyTorch의 설계 원칙](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy)에 기반합니다. 이제 가장 중요한 것들을 살펴보겠습니다:
|
||||
간단히 말해서, Diffusers는 PyTorch의 자연스러운 확장이 되도록 구축되었습니다. 따라서 대부분의 설계 선택은 [PyTorch의 설계 원칙](https://pytorch.org/docs/stable/community/design.html#pytorch-design-philosophy)에 기반합니다. 이제 가장 중요한 것들을 살펴보겠습니다:
|
||||
|
||||
## 성능보다는 사용성을 [[usability-over-performance]]
|
||||
## 성능보다는 사용성을
|
||||
|
||||
- Diffusers는 다양한 성능 향상 기능이 내장되어 있지만 (자세한 내용은 [메모리와 속도](https://huggingface.co/docs/diffusers/optimization/fp16) 참조), 모델은 항상 가장 높은 정밀도와 최소한의 최적화로 로드됩니다. 따라서 사용자가 별도로 정의하지 않는 한 기본적으로 diffusion 파이프라인은 항상 float32 정밀도로 CPU에 인스턴스화됩니다. 이는 다양한 플랫폼과 가속기에서의 사용성을 보장하며, 라이브러리를 실행하기 위해 복잡한 설치가 필요하지 않다는 것을 의미합니다.
|
||||
- Diffusers는 많은 내장 성능 향상 기능을 갖고 있지만 (자세한 내용은 [메모리와 속도](https://huggingface.co/docs/diffusers/optimization/fp16) 참조), 모델은 항상 가장 높은 정밀도와 최소한의 최적화로 로드됩니다. 따라서 기본적인 diffusion 파이프라인은 따로 정의하지 않는다면 CPU에서 float32 정밀도로 인스턴스화됩니다. 이는 다양한 플랫폼과 가속기에서의 사용성을 보장하며, 라이브러리를 실행하기 위해 복잡한 설치가 필요하지 않음을 의미합니다.
|
||||
- Diffusers는 **가벼운** 패키지를 지향하기 때문에 필수 종속성은 거의 없지만 성능을 향상시킬 수 있는 많은 선택적 종속성이 있습니다 (`accelerate`, `safetensors`, `onnx` 등). 저희는 라이브러리를 가능한 한 가볍게 유지하여 다른 패키지에 대한 종속성 걱정이 없도록 노력하고 있습니다.
|
||||
- Diffusers는 간결하고 이해하기 쉬운 코드를 선호합니다. 이는 람다 함수나 고급 PyTorch 연산자와 같은 압축된 코드 구문을 자주 사용하지 않는 것을 의미합니다.
|
||||
|
||||
## 쉬움보다는 간단함을 [[simple-over-easy]]
|
||||
## 쉬움보다는 간단함을
|
||||
|
||||
PyTorch에서는 **명시적인 것이 암시적인 것보다 낫다**와 **단순한 것이 복잡한 것보다 낫다**라고 말합니다. 이 설계 철학은 라이브러리의 여러 부분에 반영되어 있습니다:
|
||||
- [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to)와 같은 메소드를 사용하여 사용자가 장치 관리를 할 수 있도록 PyTorch의 API를 따릅니다.
|
||||
- [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to)와 같은 메서드를 사용하여 사용자가 장치 관리를 할 수 있도록 PyTorch의 API를 따릅니다.
|
||||
- 잘못된 입력을 조용히 수정하는 대신 간결한 오류 메시지를 발생시키는 것이 우선입니다. Diffusers는 라이브러리를 가능한 한 쉽게 사용할 수 있도록 하는 것보다 사용자를 가르치는 것을 목표로 합니다.
|
||||
- 복잡한 모델과 스케줄러 로직이 내부에서 마법처럼 처리하는 대신 노출됩니다. 스케줄러/샘플러는 서로에게 최소한의 종속성을 가지고 분리되어 있습니다. 이로써 사용자는 언롤된 노이즈 제거 루프를 작성해야 합니다. 그러나 이 분리는 디버깅을 더 쉽게하고 노이즈 제거 과정을 조정하거나 diffusers 모델이나 스케줄러를 교체하는 데 사용자에게 더 많은 제어권을 제공합니다.
|
||||
- diffusers 파이프라인의 따로 훈련된 구성 요소인 text encoder, unet 및 variational autoencoder는 각각 자체 모델 클래스를 갖습니다. 이로써 사용자는 서로 다른 모델의 구성 요소 간의 상호 작용을 처리해야 하며, 직렬화 형식은 모델 구성 요소를 다른 파일로 분리합니다. 그러나 이는 디버깅과 커스터마이징을 더 쉽게합니다. DreamBooth나 Textual Inversion 훈련은 Diffusers의 'diffusion 파이프라인의 단일 구성 요소들을 분리할 수 있는 능력' 덕분에 매우 간단합니다.
|
||||
|
||||
## 추상화보다는 수정 가능하고 기여하기 쉬움을 [[tweakable-contributor-friendly-over-abstraction]]
|
||||
## 추상화보다는 수정 가능하고 기여하기 쉬움을
|
||||
|
||||
라이브러리의 대부분에 대해 Diffusers는 [Transformers 라이브러리](https://github.com/huggingface/transformers)의 중요한 설계 원칙을 채택합니다, 바로 성급한 추상화보다는 copy-pasted 코드를 선호한다는 것입니다. 이 설계 원칙은 [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself)와 같은 인기 있는 설계 원칙과는 대조적으로 매우 의견이 분분한데요.
|
||||
간단히 말해서, Transformers가 모델링 파일에 대해 수행하는 것처럼, Diffusers는 매우 낮은 수준의 추상화와 매우 독립적인 코드를 유지하는 것을 선호합니다. 함수, 긴 코드 블록, 심지어 클래스도 여러 파일에 복사할 수 있으며, 이는 처음에는 라이브러리를 유지할 수 없게 만드는 나쁜, 서투른 설계 선택으로 보일 수 있습니다. 하지만 이러한 설계는 매우 성공적이며, 커뮤니티 기반의 오픈 소스 기계 학습 라이브러리에 매우 적합합니다. 그 이유는 다음과 같습니다:
|
||||
@@ -48,16 +48,16 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
|
||||
좋아요, 이제 🧨 Diffusers가 설계된 방식을 대략적으로 이해했을 것입니다 🤗.
|
||||
우리는 이러한 설계 원칙을 일관되게 라이브러리 전체에 적용하려고 노력하고 있습니다. 그럼에도 불구하고 철학에 대한 일부 예외 사항이나 불행한 설계 선택이 있을 수 있습니다. 디자인에 대한 피드백이 있다면 [GitHub에서 직접](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) 알려주시면 감사하겠습니다.
|
||||
|
||||
## 디자인 철학 자세히 알아보기 [[design-philosophy-in-details]]
|
||||
## 디자인 철학 자세히 알아보기
|
||||
|
||||
이제 디자인 철학의 세부 사항을 좀 더 자세히 살펴보겠습니다. Diffusers는 주로 세 가지 주요 클래스로 구성됩니다: [파이프라인](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines), [모델](https://github.com/huggingface/diffusers/tree/main/src/diffusers/models), 그리고 [스케줄러](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers). 각 클래스에 대한 더 자세한 설계 결정 사항을 살펴보겠습니다.
|
||||
|
||||
### 파이프라인 [[pipelines]]
|
||||
### 파이프라인
|
||||
|
||||
파이프라인은 사용하기 쉽도록 설계되었으며 (따라서 [*쉬움보다는 간단함을*](#쉬움보다는-간단함을)을 100% 따르지는 않음), feature-complete하지 않으며, 추론을 위한 [모델](#모델)과 [스케줄러](#스케줄러)를 사용하는 방법의 예시로 간주될 수 있습니다.
|
||||
|
||||
다음과 같은 설계 원칙을 따릅니다:
|
||||
- 파이프라인은 단일 파일 정책을 따릅니다. 모든 파이프라인은 src/diffusers/pipelines의 개별 디렉토리에 있습니다. 하나의 파이프라인 폴더는 하나의 diffusion 논문/프로젝트/릴리스에 해당합니다. 여러 파이프라인 파일은 하나의 파이프라인 폴더에 모을 수 있습니다. 예를 들어 [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion)에서 그렇게 하고 있습니다. 파이프라인이 유사한 기능을 공유하는 경우, [# Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251)을 사용할 수 있습니다.
|
||||
- 파이프라인은 단일 파일 정책을 따릅니다. 모든 파이프라인은 src/diffusers/pipelines의 개별 디렉토리에 있습니다. 하나의 파이프라인 폴더는 하나의 diffusion 논문/프로젝트/릴리스에 해당합니다. 여러 파이프라인 파일은 하나의 파이프라인 폴더에 모을 수 있습니다. 예를 들어 [`src/diffusers/pipelines/stable-diffusion`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines/stable_diffusion)에서 그렇게 하고 있습니다. 파이프라인이 유사한 기능을 공유하는 경우, [#Copied from mechanism](https://github.com/huggingface/diffusers/blob/125d783076e5bd9785beb05367a2d2566843a271/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L251)을 사용할 수 있습니다.
|
||||
- 파이프라인은 모두 [`DiffusionPipeline`]을 상속합니다.
|
||||
- 각 파이프라인은 서로 다른 모델 및 스케줄러 구성 요소로 구성되어 있으며, 이는 [`model_index.json` 파일](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/model_index.json)에 문서화되어 있으며, 파이프라인의 속성 이름과 동일한 이름으로 액세스할 수 있으며, [`DiffusionPipeline.components`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.components) 함수를 통해 파이프라인 간에 공유할 수 있습니다.
|
||||
- 각 파이프라인은 [`DiffusionPipeline.from_pretrained`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained) 함수를 통해 로드할 수 있어야 합니다.
|
||||
@@ -65,11 +65,11 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
|
||||
- 파이프라인은 매우 가독성이 좋고, 이해하기 쉽고, 쉽게 조정할 수 있도록 설계되어야 합니다.
|
||||
- 파이프라인은 서로 상호작용하고, 상위 수준 API에 쉽게 통합할 수 있도록 설계되어야 합니다.
|
||||
- 파이프라인은 사용자 인터페이스가 feature-complete하지 않게 하는 것을 목표로 합니다. future-complete한 사용자 인터페이스를 원한다면 [InvokeAI](https://github.com/invoke-ai/InvokeAI), [Diffuzers](https://github.com/abhishekkrthakur/diffuzers), [lama-cleaner](https://github.com/Sanster/lama-cleaner)를 참조해야 합니다.
|
||||
- 모든 파이프라인은 오로지 `__call__` 메소드를 통해 실행할 수 있어야 합니다. `__call__` 인자의 이름은 모든 파이프라인에서 공유되어야 합니다.
|
||||
- 모든 파이프라인은 오로지 `__call__` 메서드를 통해 실행할 수 있어야 합니다. `__call__` 인자의 이름은 모든 파이프라인에서 공유되어야 합니다.
|
||||
- 파이프라인은 해결하고자 하는 작업의 이름으로 지정되어야 합니다.
|
||||
- 대부분의 경우에 새로운 diffusion 파이프라인은 새로운 파이프라인 폴더/파일에 구현되어야 합니다.
|
||||
|
||||
### 모델 [[models]]
|
||||
### 모델
|
||||
|
||||
모델은 [PyTorch의 Module 클래스](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)의 자연스러운 확장이 되도록, 구성 가능한 툴박스로 설계되었습니다. 그리고 모델은 **단일 파일 정책**을 일부만 따릅니다.
|
||||
|
||||
@@ -85,7 +85,7 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
|
||||
- 모델은 미래의 변경 사항을 쉽게 확장할 수 있도록 설계되어야 합니다. 이는 공개 함수 인수들과 구성 인수들을 제한하고,미래의 변경 사항을 "예상"하는 것을 통해 달성할 수 있습니다. 예를 들어, 불리언 `is_..._type` 인수보다는 새로운 미래 유형에 쉽게 확장할 수 있는 문자열 "...type" 인수를 추가하는 것이 일반적으로 더 좋습니다. 새로운 모델 체크포인트가 작동하도록 하기 위해 기존 아키텍처에 최소한의 변경만을 가해야 합니다.
|
||||
- 모델 디자인은 코드의 가독성과 간결성을 유지하는 것과 많은 모델 체크포인트를 지원하는 것 사이의 어려운 균형 조절입니다. 모델링 코드의 대부분은 새로운 모델 체크포인트를 위해 클래스를 수정하는 것이 좋지만, [UNet 블록](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) 및 [Attention 프로세서](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)와 같이 코드를 장기적으로 간결하고 읽기 쉽게 유지하기 위해 새로운 클래스를 추가하는 예외도 있습니다.
|
||||
|
||||
### 스케줄러 [[schedulers]]
|
||||
### 스케줄러
|
||||
|
||||
스케줄러는 추론을 위한 노이즈 제거 과정을 안내하고 훈련을 위한 노이즈 스케줄을 정의하는 역할을 합니다. 스케줄러는 개별 클래스로 설계되어 있으며, 로드 가능한 구성 파일과 **단일 파일 정책**을 엄격히 따릅니다.
|
||||
|
||||
@@ -93,9 +93,9 @@ Diffusers에서는 이러한 철학을 파이프라인과 스케줄러에 모두
|
||||
- 모든 스케줄러는 [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers)에서 찾을 수 있습니다.
|
||||
- 스케줄러는 큰 유틸리티 파일에서 가져오지 **않아야** 하며, 자체 포함성을 유지해야 합니다.
|
||||
- 하나의 스케줄러 Python 파일은 하나의 스케줄러 알고리즘(논문에서 정의된 것과 같은)에 해당합니다.
|
||||
- 스케줄러가 유사한 기능을 공유하는 경우, `# Copied from` 메커니즘을 사용할 수 있습니다.
|
||||
- 스케줄러가 유사한 기능을 공유하는 경우, `#Copied from` 메커니즘을 사용할 수 있습니다.
|
||||
- 모든 스케줄러는 `SchedulerMixin`과 `ConfigMixin`을 상속합니다.
|
||||
- [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) 메소드를 사용하여 스케줄러를 쉽게 교체할 수 있습니다. 자세한 내용은 [여기](../using-diffusers/schedulers.md)에서 설명합니다.
|
||||
- [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) 메서드를 사용하여 스케줄러를 쉽게 교체할 수 있습니다. 자세한 내용은 [여기](../using-diffusers/schedulers.md)에서 설명합니다.
|
||||
- 모든 스케줄러는 `set_num_inference_steps`와 `step` 함수를 가져야 합니다. `set_num_inference_steps(...)`는 각 노이즈 제거 과정(즉, `step(...)`이 호출되기 전) 이전에 호출되어야 합니다.
|
||||
- 각 스케줄러는 모델이 호출될 타임스텝의 배열인 `timesteps` 속성을 통해 루프를 돌 수 있는 타임스텝을 노출합니다.
|
||||
- `step(...)` 함수는 예측된 모델 출력과 "현재" 샘플(x_t)을 입력으로 받고, "이전" 약간 더 노이즈가 제거된 샘플(x_t-1)을 반환합니다.
|
||||
|
||||
@@ -58,7 +58,7 @@ outputs = pipeline(
|
||||
)
|
||||
```
|
||||
|
||||
더 많은 정보를 얻기 위해, Optimum Habana의 [문서](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)와 공식 GitHub 저장소에 제공된 [예시](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)를 확인하세요.
|
||||
더 많은 정보를 얻기 위해, Optimum Habana의 [문서](https://huggingface.co/docs/optimum/habana/usage_guides/stable_diffusion)와 공식 Github 저장소에 제공된 [예시](https://github.com/huggingface/optimum-habana/tree/main/examples/stable-diffusion)를 확인하세요.
|
||||
|
||||
|
||||
## 벤치마크
|
||||
|
||||
@@ -296,7 +296,7 @@ scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
model_id,
|
||||
scheduler=scheduler,
|
||||
variant="bf16",
|
||||
revision="bf16",
|
||||
dtype=jax.numpy.bfloat16,
|
||||
)
|
||||
params["scheduler"] = scheduler_state
|
||||
|
||||
@@ -83,7 +83,7 @@ Flax는 함수형 프레임워크이므로 모델은 무상태(stateless)형이
|
||||
```python
|
||||
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
variant="bf16",
|
||||
revision="bf16",
|
||||
dtype=dtype,
|
||||
)
|
||||
```
|
||||
|
||||
@@ -1290,7 +1290,6 @@ def main(args):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
@@ -1525,22 +1524,17 @@ def main(args):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1557,14 +1551,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
@@ -1857,10 +1845,10 @@ def main(args):
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
if torch.backends.mps.is_available():
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
@@ -1881,6 +1869,7 @@ def main(args):
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -1982,7 +1971,7 @@ def main(args):
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
|
||||
@@ -31,6 +31,8 @@ from typing import List, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# imports of the TokenEmbeddingsHandler class
|
||||
import torch.utils.checkpoint
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
@@ -75,9 +77,6 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.30.0.dev0")
|
||||
|
||||
@@ -102,12 +101,12 @@ def save_model_card(
|
||||
repo_id: str,
|
||||
use_dora: bool,
|
||||
images=None,
|
||||
base_model: str = None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
train_text_encoder_ti=False,
|
||||
token_abstraction_dict=None,
|
||||
instance_prompt: str = None,
|
||||
validation_prompt: str = None,
|
||||
instance_prompt=str,
|
||||
validation_prompt=str,
|
||||
repo_folder=None,
|
||||
vae_path=None,
|
||||
):
|
||||
@@ -136,14 +135,6 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = ""
|
||||
diffusers_example_pivotal = ""
|
||||
webui_example_pivotal = ""
|
||||
license = ""
|
||||
if "playground" in base_model:
|
||||
license = """\n
|
||||
## License
|
||||
|
||||
Please adhere to the licensing terms as described [here](https://huggingface.co/playgroundai/playground-v2.5-1024px-aesthetic/blob/main/LICENSE.md).
|
||||
"""
|
||||
|
||||
if train_text_encoder_ti:
|
||||
trigger_str = (
|
||||
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
|
||||
@@ -232,75 +223,11 @@ Pivotal tuning was enabled: {train_text_encoder_ti}.
|
||||
|
||||
Special VAE used for training: {vae_path}.
|
||||
|
||||
{license}
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
is_final_validation=False,
|
||||
):
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if not args.do_edm_style_training:
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
|
||||
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
# Currently the context determination is a bit hand-wavy. We can improve it in the future if there's a better
|
||||
# way to condition it. Reference: https://github.com/huggingface/diffusers/pull/7126#issuecomment-1968523051
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
phase_name = "test" if is_final_validation else "validation"
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
phase_name: [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(
|
||||
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
|
||||
):
|
||||
@@ -463,7 +390,6 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_edm_style_training",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Flag to conduct training using the EDM formulation as introduced in https://arxiv.org/abs/2206.00364.",
|
||||
)
|
||||
@@ -573,13 +499,6 @@ def parse_args(input_args=None):
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clip_skip",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that "
|
||||
"the output of the pre-final layer will be used for computing the prompt embeddings.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--text_encoder_lr",
|
||||
@@ -652,7 +571,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--optimizer",
|
||||
type=str,
|
||||
default="AdamW",
|
||||
default="adamW",
|
||||
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
|
||||
)
|
||||
|
||||
@@ -987,6 +906,11 @@ class DreamBoothDataset(Dataset):
|
||||
instance_data_root,
|
||||
instance_prompt,
|
||||
class_prompt,
|
||||
dataset_name,
|
||||
dataset_config_name,
|
||||
cache_dir,
|
||||
image_column,
|
||||
caption_column,
|
||||
train_text_encoder_ti,
|
||||
class_data_root=None,
|
||||
class_num=None,
|
||||
@@ -1005,7 +929,7 @@ class DreamBoothDataset(Dataset):
|
||||
self.train_text_encoder_ti = train_text_encoder_ti
|
||||
# if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
|
||||
# we load the training data using load_dataset
|
||||
if args.dataset_name is not None:
|
||||
if dataset_name is not None:
|
||||
try:
|
||||
from datasets import load_dataset
|
||||
except ImportError:
|
||||
@@ -1018,26 +942,25 @@ class DreamBoothDataset(Dataset):
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
dataset_name,
|
||||
dataset_config_name,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
# Preprocessing the datasets.
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
# 6. Get the column names for input/target.
|
||||
if args.image_column is None:
|
||||
if image_column is None:
|
||||
image_column = column_names[0]
|
||||
logger.info(f"image column defaulting to {image_column}")
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
f"`--image_column` value '{image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
)
|
||||
instance_images = dataset["train"][image_column]
|
||||
|
||||
if args.caption_column is None:
|
||||
if caption_column is None:
|
||||
logger.info(
|
||||
"No caption column provided, defaulting to instance_prompt for all images. If your dataset "
|
||||
"contains captions/prompts for the images, make sure to specify the "
|
||||
@@ -1045,11 +968,11 @@ class DreamBoothDataset(Dataset):
|
||||
)
|
||||
self.custom_instance_prompts = None
|
||||
else:
|
||||
if args.caption_column not in column_names:
|
||||
if caption_column not in column_names:
|
||||
raise ValueError(
|
||||
f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
f"`--caption_column` value '{caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
)
|
||||
custom_instance_prompts = dataset["train"][args.caption_column]
|
||||
custom_instance_prompts = dataset["train"][caption_column]
|
||||
# create final list of captions according to --repeats
|
||||
self.custom_instance_prompts = []
|
||||
for caption in custom_instance_prompts:
|
||||
@@ -1243,7 +1166,7 @@ def tokenize_prompt(tokenizer, prompt, add_special_tokens=False):
|
||||
|
||||
|
||||
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
|
||||
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, clip_skip=None):
|
||||
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
|
||||
prompt_embeds_list = []
|
||||
|
||||
for i, text_encoder in enumerate(text_encoders):
|
||||
@@ -1255,16 +1178,13 @@ def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None, c
|
||||
text_input_ids = text_input_ids_list[i]
|
||||
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False
|
||||
text_input_ids.to(text_encoder.device),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds[-1][-2]
|
||||
else:
|
||||
# "2" because SDXL always indexes from the penultimate layer.
|
||||
prompt_embeds = prompt_embeds[-1][-(clip_skip + 2)]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
@@ -1280,16 +1200,9 @@ def main(args):
|
||||
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
|
||||
" Please use `huggingface-cli login` to authenticate with the Hub."
|
||||
)
|
||||
|
||||
if args.do_edm_style_training and args.snr_gamma is not None:
|
||||
raise ValueError("Min-SNR formulation is not supported when conducting EDM-style training.")
|
||||
|
||||
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
@@ -1302,13 +1215,10 @@ def main(args):
|
||||
kwargs_handlers=[kwargs],
|
||||
)
|
||||
|
||||
# Disable AMP for MPS.
|
||||
if torch.backends.mps.is_available():
|
||||
accelerator.native_amp = False
|
||||
|
||||
if args.report_to == "wandb":
|
||||
if not is_wandb_available():
|
||||
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
||||
import wandb
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
@@ -1336,8 +1246,7 @@ def main(args):
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
|
||||
torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
if args.prior_generation_precision == "fp32":
|
||||
torch_dtype = torch.float32
|
||||
elif args.prior_generation_precision == "fp16":
|
||||
@@ -1495,12 +1404,6 @@ def main(args):
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
|
||||
# due to pytorch#99272, MPS does not yet support bfloat16.
|
||||
raise ValueError(
|
||||
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
|
||||
)
|
||||
|
||||
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
@@ -1605,13 +1508,15 @@ def main(args):
|
||||
if isinstance(model, type(unwrap_model(unet))):
|
||||
unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1659,7 +1564,6 @@ def main(args):
|
||||
)
|
||||
|
||||
if args.train_text_encoder:
|
||||
# Do we need to call `scale_lora_layers()` here?
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
|
||||
|
||||
_set_state_dict_into_text_encoder(
|
||||
@@ -1674,14 +1578,14 @@ def main(args):
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one_, text_encoder_two_])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
cast_training_params(models)
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if args.allow_tf32 and torch.cuda.is_available():
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.scale_lr:
|
||||
@@ -1807,7 +1711,12 @@ def main(args):
|
||||
instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
class_prompt=args.class_prompt,
|
||||
dataset_name=args.dataset_name,
|
||||
dataset_config_name=args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
image_column=args.image_column,
|
||||
train_text_encoder_ti=args.train_text_encoder_ti,
|
||||
caption_column=args.caption_column,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
token_abstraction_dict=token_abstraction_dict if args.train_text_encoder_ti else None,
|
||||
class_num=args.num_class_images,
|
||||
@@ -1831,6 +1740,8 @@ def main(args):
|
||||
|
||||
def compute_time_ids(crops_coords_top_left, original_size=None):
|
||||
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
||||
if original_size is None:
|
||||
original_size = (args.resolution, args.resolution)
|
||||
target_size = (args.resolution, args.resolution)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
@@ -1841,9 +1752,9 @@ def main(args):
|
||||
tokenizers = [tokenizer_one, tokenizer_two]
|
||||
text_encoders = [text_encoder_one, text_encoder_two]
|
||||
|
||||
def compute_text_embeddings(prompt, text_encoders, tokenizers, clip_skip):
|
||||
def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt, clip_skip)
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
@@ -1853,7 +1764,7 @@ def main(args):
|
||||
# the redundant encoding.
|
||||
if freeze_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.instance_prompt, text_encoders, tokenizers, args.clip_skip
|
||||
args.instance_prompt, text_encoders, tokenizers
|
||||
)
|
||||
|
||||
# Handle class prompt for prior-preservation.
|
||||
@@ -1867,8 +1778,7 @@ def main(args):
|
||||
if freeze_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
@@ -1910,22 +1820,17 @@ def main(args):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1942,14 +1847,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
@@ -2047,8 +1946,8 @@ def main(args):
|
||||
text_encoder_two.train()
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
if args.train_text_encoder:
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
text_encoder_one.text_model.embeddings.requires_grad_(True)
|
||||
text_encoder_two.text_model.embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
if pivoted:
|
||||
@@ -2063,7 +1962,7 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if freeze_text_encoder:
|
||||
prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
|
||||
prompts, text_encoders, tokenizers, args.clip_skip
|
||||
prompts, text_encoders, tokenizers
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -2141,6 +2040,7 @@ def main(args):
|
||||
if freeze_text_encoder:
|
||||
unet_added_conditions = {
|
||||
"time_ids": add_time_ids,
|
||||
# "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
@@ -2158,7 +2058,6 @@ def main(args):
|
||||
tokenizers=None,
|
||||
prompt=None,
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
clip_skip=args.clip_skip,
|
||||
)
|
||||
unet_added_conditions.update(
|
||||
{"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
|
||||
@@ -2321,6 +2220,10 @@ def main(args):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
||||
logger.info(
|
||||
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
if freeze_text_encoder:
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
@@ -2347,29 +2250,70 @@ def main(args):
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
|
||||
images = log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
)
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if not args.do_edm_style_training:
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, **scheduler_args
|
||||
)
|
||||
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
if torch.backends.mps.is_available() or "playground" in args.pretrained_model_name_or_path:
|
||||
autocast_ctx = nullcontext()
|
||||
else:
|
||||
autocast_ctx = torch.autocast(accelerator.device.type)
|
||||
|
||||
with autocast_ctx:
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unet = unwrap_model(unet)
|
||||
unet = accelerator.unwrap_model(unet)
|
||||
unet = unet.to(torch.float32)
|
||||
unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = unwrap_model(text_encoder_one)
|
||||
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
)
|
||||
text_encoder_two = unwrap_model(text_encoder_two)
|
||||
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
|
||||
get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
)
|
||||
@@ -2388,44 +2332,90 @@ def main(args):
|
||||
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
|
||||
embedding_handler.save_embeddings(embeddings_path)
|
||||
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# run inference
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
|
||||
images = log_validation(
|
||||
pipeline,
|
||||
args,
|
||||
accelerator,
|
||||
pipeline_args,
|
||||
epoch,
|
||||
is_final_validation=True,
|
||||
# Final inference
|
||||
# Load previous pipeline
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
vae_path,
|
||||
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
|
||||
scheduler_args = {}
|
||||
|
||||
if not args.do_edm_style_training:
|
||||
if "variance_type" in pipeline.scheduler.config:
|
||||
variance_type = pipeline.scheduler.config.variance_type
|
||||
|
||||
if variance_type in ["learned", "learned_range"]:
|
||||
variance_type = "fixed_small"
|
||||
|
||||
scheduler_args["variance_type"] = variance_type
|
||||
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
|
||||
pipeline.scheduler.config, **scheduler_args
|
||||
)
|
||||
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# load new tokens
|
||||
if args.train_text_encoder_ti:
|
||||
state_dict = load_file(embeddings_path)
|
||||
all_new_tokens = []
|
||||
for key, value in token_abstraction_dict.items():
|
||||
all_new_tokens.extend(value)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_g"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder_2,
|
||||
tokenizer=pipeline.tokenizer_2,
|
||||
)
|
||||
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
images = [
|
||||
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
]
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
||||
if tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"test": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
# Convert to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict)
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{Path(args.output_dir).name}.safetensors")
|
||||
save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors")
|
||||
|
||||
save_model_card(
|
||||
model_id if not args.push_to_hub else repo_id,
|
||||
@@ -2440,7 +2430,6 @@ def main(args):
|
||||
repo_folder=args.output_dir,
|
||||
vae_path=args.pretrained_vae_model_name_or_path,
|
||||
)
|
||||
|
||||
if args.push_to_hub:
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
|
||||
+213
-301
File diff suppressed because it is too large
Load Diff
@@ -71,7 +71,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
**kwargs:
|
||||
Supports all the default DiffusionPipeline.get_config_dict kwargs viz..
|
||||
|
||||
cache_dir, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.
|
||||
cache_dir, resume_download, force_download, proxies, local_files_only, token, revision, torch_dtype, device_map.
|
||||
|
||||
alpha - The interpolation parameter. Ranges from 0 to 1. It affects the ratio in which the checkpoints are merged. A 0.8 alpha
|
||||
would mean that the first model checkpoints would affect the final result far less than an alpha of 0.2
|
||||
@@ -86,6 +86,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
"""
|
||||
# Default kwargs from DiffusionPipeline
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
@@ -123,6 +124,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
config_dict = DiffusionPipeline.load_config(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
@@ -158,6 +160,7 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
else snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
|
||||
@@ -267,6 +267,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_name, **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -282,6 +283,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# A SDXL pipeline can take unlimited weighted prompt
|
||||
#
|
||||
# Author: Andrew Zhu
|
||||
# GitHub: https://github.com/xhinker
|
||||
# Github: https://github.com/xhinker
|
||||
# Medium: https://medium.com/@xhinker
|
||||
## -----------------------------------------------------------
|
||||
|
||||
@@ -2165,7 +2165,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
@@ -2188,7 +2188,7 @@ class SDXLLongPromptWeightingPipeline(
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
cls.write_lora_layers(
|
||||
self.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1339,7 +1339,7 @@ class DemoFusionSDXLPipeline(
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
@@ -1368,7 +1368,7 @@ class DemoFusionSDXLPipeline(
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
cls.write_lora_layers(
|
||||
self.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
|
||||
@@ -1,981 +0,0 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from diffusers.models.transformers import SD3Transformer2DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
|
||||
>>> from diffusers import AutoPipelineForImage2Image
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> device = "cuda"
|
||||
>>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
>>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to(device)
|
||||
|
||||
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
>>> init_image = load_image(url).resize((512, 512))
|
||||
|
||||
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
|
||||
|
||||
>>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class StableDiffusion3DifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
||||
with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
||||
as its dimension.
|
||||
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the
|
||||
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
||||
variant.
|
||||
text_encoder_3 ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Stable Diffusion 3 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Second Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: SD3Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
text_encoder_3=text_encoder_3,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
tokenizer_3=tokenizer_3,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
|
||||
)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_convert_grayscale=True
|
||||
)
|
||||
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length
|
||||
self.default_sample_size = self.transformer.config.sample_size
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int = 256,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if self.text_encoder_3 is None:
|
||||
return torch.zeros(
|
||||
(
|
||||
batch_size * num_images_per_prompt,
|
||||
self.tokenizer_max_length,
|
||||
self.transformer.config.joint_attention_dim,
|
||||
),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer_3(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because `max_sequence_length` is set to "
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = self.text_encoder_3.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
clip_model_index: int = 0,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
||||
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
||||
|
||||
tokenizer = clip_tokenizers[clip_model_index]
|
||||
text_encoder = clip_text_encoders[clip_model_index]
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_3: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
max_sequence_length: int = 256,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
device: (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||
|
||||
prompt_3 = prompt_3 or prompt
|
||||
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
||||
|
||||
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=0,
|
||||
)
|
||||
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=1,
|
||||
)
|
||||
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
||||
|
||||
t5_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=prompt_3,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
clip_prompt_embeds = torch.nn.functional.pad(
|
||||
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
||||
|
||||
# normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_2 = (
|
||||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
)
|
||||
negative_prompt_3 = (
|
||||
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
||||
)
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=0,
|
||||
)
|
||||
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=1,
|
||||
)
|
||||
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
||||
|
||||
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt_3,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
||||
negative_clip_prompt_embeds,
|
||||
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
||||
)
|
||||
|
||||
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
||||
negative_pooled_prompt_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
strength,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
negative_prompt_3=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_3 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
||||
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
if max_sequence_length is not None and max_sequence_length > 512:
|
||||
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_latents(
|
||||
self, batch_size, num_channels_latents, height, width, image, timestep, dtype, device, generator=None
|
||||
):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
elif isinstance(generator, list):
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
|
||||
latents = init_latents.to(device=device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
image: PipelineImageInput = None,
|
||||
strength: float = 0.6,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 256,
|
||||
map: PipelineImageInput = None,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
||||
will be used instead
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
||||
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
# 0. Default height and width
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
strength,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_3=prompt_3,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
clip_skip=self.clip_skip,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
|
||||
# 3. Preprocess image
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
|
||||
map = self.mask_processor.preprocess(
|
||||
map, height=height // self.vae_scale_factor, width=width // self.vae_scale_factor
|
||||
).to(device)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
|
||||
# begin diff diff change
|
||||
total_time_steps = num_inference_steps
|
||||
# end diff diff change
|
||||
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
if latents is None:
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
init_image,
|
||||
latent_timestep,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# preparations for diff diff
|
||||
original_with_noise = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
init_image,
|
||||
timesteps,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
thresholds = torch.arange(total_time_steps, dtype=map.dtype) / total_time_steps
|
||||
thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device)
|
||||
masks = map.squeeze() > thresholds
|
||||
# end diff diff preparations
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# diff diff
|
||||
if i == 0:
|
||||
latents = original_with_noise[:1]
|
||||
else:
|
||||
mask = masks[i].unsqueeze(0).to(latents.dtype)
|
||||
mask = mask.unsqueeze(1) # fit shape
|
||||
latents = original_with_noise[i] * mask + latents * (1 - mask)
|
||||
# end diff diff
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusion3PipelineOutput(images=image)
|
||||
@@ -467,6 +467,8 @@ def make_emblist(self, prompts):
|
||||
|
||||
|
||||
def split_dims(xs, height, width):
|
||||
xs = xs
|
||||
|
||||
def repeat_div(x, y):
|
||||
while y > 0:
|
||||
x = math.ceil(x / 2)
|
||||
|
||||
@@ -783,6 +783,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
@validate_hf_hub_args
|
||||
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -794,6 +795,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline):
|
||||
else snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
|
||||
@@ -783,6 +783,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
@validate_hf_hub_args
|
||||
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -794,6 +795,7 @@ class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline):
|
||||
else snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
|
||||
@@ -695,6 +695,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
@validate_hf_hub_args
|
||||
def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -706,6 +707,7 @@ class TensorRTStableDiffusionPipeline(StableDiffusionPipeline):
|
||||
else snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
|
||||
@@ -282,7 +282,7 @@ class StableDiffusionTiledUpscalePipeline(StableDiffusionUpscalePipeline):
|
||||
def main():
|
||||
# Run a demo
|
||||
model_id = "stabilityai/stable-diffusion-x4-upscaler"
|
||||
pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, variant="fp16", torch_dtype=torch.float16)
|
||||
pipe = StableDiffusionTiledUpscalePipeline.from_pretrained(model_id, revision="fp16", torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
image = Image.open("../../docs/source/imgs/diffusers_library.jpg")
|
||||
|
||||
|
||||
@@ -1088,22 +1088,17 @@ def main(args):
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
|
||||
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
|
||||
num_training_steps_for_scheduler = (
|
||||
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=num_warmup_steps_for_scheduler,
|
||||
num_training_steps=num_training_steps_for_scheduler,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
num_cycles=args.lr_num_cycles,
|
||||
power=args.lr_power,
|
||||
)
|
||||
@@ -1115,14 +1110,8 @@ def main(args):
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
|
||||
logger.warning(
|
||||
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
|
||||
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
|
||||
f"This inconsistency may result in the learning rate scheduler not functioning properly."
|
||||
)
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
|
||||
@@ -147,40 +147,6 @@ accelerate launch train_dreambooth_lora_sd3.py \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
### Text Encoder Training
|
||||
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
|
||||
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:
|
||||
|
||||
> [!NOTE]
|
||||
> SD3 has three text encoders (CLIP L/14, OpenCLIP bigG/14, and T5-v1.1-XXL).
|
||||
By enabling `--train_text_encoder`, LoRA fine-tuning of both **CLIP encoders** is performed. At the moment, T5 fine-tuning is not supported and weights remain frozen when text encoder training is enabled.
|
||||
|
||||
To perform DreamBooth LoRA with text-encoder training, run:
|
||||
```bash
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
export OUTPUT_DIR="trained-sd3-lora"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sd3.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--dataset_name="Norod78/Yarn-art-style" \
|
||||
--instance_prompt="a photo of TOK yarn art dog" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--train_text_encoder\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--optimizer="prodigy"\
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--report_to="wandb" \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=1500 \
|
||||
--rank=32 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
## Other notes
|
||||
|
||||
We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
|
||||
@@ -261,7 +261,7 @@ The authors found that by using DoRA, both the learning capacity and training st
|
||||
**Usage**
|
||||
1. To use DoRA you need to upgrade the installation of `peft`:
|
||||
```bash
|
||||
pip install -U peft
|
||||
pip install-U peft
|
||||
```
|
||||
2. Enable DoRA training by adding this flag
|
||||
```bash
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import safetensors
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothLoRASD3(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"
|
||||
|
||||
def test_dreambooth_lora_sd3(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
# when not training the text encoder, all the parameters in the state dict should start
|
||||
# with `"transformer"` in their names.
|
||||
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
|
||||
self.assertTrue(starts_with_transformer)
|
||||
|
||||
def test_dreambooth_lora_text_encoder_sd3(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--train_text_encoder
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
|
||||
|
||||
# make sure the state_dict has the correct naming in the parameters.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
|
||||
is_lora = all("lora" in k for k in lora_state_dict.keys())
|
||||
self.assertTrue(is_lora)
|
||||
|
||||
starts_with_expected_prefix = all(
|
||||
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
|
||||
)
|
||||
self.assertTrue(starts_with_expected_prefix)
|
||||
|
||||
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -1,203 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from diffusers import DiffusionPipeline, SD3Transformer2DModel
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
logger = logging.getLogger()
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
class DreamBoothSD3(ExamplesTestsAccelerate):
|
||||
instance_data_dir = "docs/source/en/imgs"
|
||||
instance_prompt = "photo"
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
|
||||
script_path = "examples/dreambooth/train_dreambooth_sd3.py"
|
||||
|
||||
def test_dreambooth(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
|
||||
|
||||
def test_dreambooth_checkpointing(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Run training script with checkpointing
|
||||
# max_train_steps == 4, checkpointing_steps == 2
|
||||
# Should create checkpoints at steps 2, 4
|
||||
|
||||
initial_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 4
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + initial_run_args)
|
||||
|
||||
# check can run the original fully trained output pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# check checkpoint directories exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
|
||||
# check can run an intermediate checkpoint
|
||||
transformer = SD3Transformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
|
||||
pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
|
||||
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
|
||||
|
||||
# Run training script for 7 total steps resuming from checkpoint 4
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
|
||||
--instance_data_dir {self.instance_data_dir}
|
||||
--instance_prompt {self.instance_prompt}
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 6
|
||||
--learning_rate 5.0e-04
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--output_dir {tmpdir}
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--seed=0
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
# check can run new fully trained pipeline
|
||||
pipe = DiffusionPipeline.from_pretrained(tmpdir)
|
||||
pipe(self.instance_prompt, num_inference_steps=1)
|
||||
|
||||
# check old checkpoints do not exist
|
||||
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
|
||||
|
||||
# check new checkpoints exist
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
|
||||
|
||||
def test_dreambooth_checkpointing_checkpoints_total_limit(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=6
|
||||
--checkpoints_total_limit=2
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-4", "checkpoint-6"},
|
||||
)
|
||||
|
||||
def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertEqual(
|
||||
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
|
||||
{"checkpoint-2", "checkpoint-4"},
|
||||
)
|
||||
|
||||
resume_run_args = f"""
|
||||
{self.script_path}
|
||||
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
|
||||
--instance_data_dir={self.instance_data_dir}
|
||||
--output_dir={tmpdir}
|
||||
--instance_prompt={self.instance_prompt}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--max_train_steps=8
|
||||
--checkpointing_steps=2
|
||||
--resume_from_checkpoint=checkpoint-4
|
||||
--checkpoints_total_limit=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + resume_run_args)
|
||||
|
||||
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
|
||||
@@ -54,7 +54,6 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_set_state_dict_into_text_encoder,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
@@ -81,7 +80,6 @@ def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
base_model: str = None,
|
||||
train_text_encoder=False,
|
||||
instance_prompt=None,
|
||||
validation_prompt=None,
|
||||
repo_folder=None,
|
||||
@@ -101,41 +99,21 @@ def save_model_card(
|
||||
|
||||
## Model description
|
||||
|
||||
These are {repo_id} DreamBooth LoRA weights for {base_model}.
|
||||
These are {repo_id} DreamBooth weights for {base_model}.
|
||||
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
|
||||
|
||||
Was LoRA for the text encoder enabled? {train_text_encoder}.
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
|
||||
|
||||
## Trigger words
|
||||
|
||||
You should use `{instance_prompt}` to trigger the image generation.
|
||||
You should use {instance_prompt} to trigger the image generation.
|
||||
|
||||
## Download model
|
||||
|
||||
[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
|
||||
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
|
||||
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
|
||||
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
|
||||
```
|
||||
|
||||
### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke
|
||||
|
||||
- **LoRA**: download **[`diffusers_lora_weights.safetensors` here 💾](/{repo_id}/blob/main/diffusers_lora_weights.safetensors)**.
|
||||
- Rename it and place it on your `models/Lora` folder.
|
||||
- On AUTOMATIC1111, load the LoRA by adding `<lora:your_new_name:1>` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/).
|
||||
|
||||
For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
|
||||
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
||||
|
||||
## License
|
||||
|
||||
Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
|
||||
Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.
|
||||
"""
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
@@ -150,7 +128,6 @@ Please adhere to the licensing terms as described [here](https://huggingface.co/
|
||||
"text-to-image",
|
||||
"diffusers-training",
|
||||
"diffusers",
|
||||
"lora",
|
||||
"sd3",
|
||||
"sd3-diffusers",
|
||||
"template:sd-lora",
|
||||
@@ -404,12 +381,6 @@ def parse_args(input_args=None):
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
help="Whether to train the text encoder (clip text encoders only). If set, the text encoder should be float32 precision.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
@@ -885,25 +856,19 @@ def _encode_prompt_with_t5(
|
||||
prompt=None,
|
||||
num_images_per_prompt=1,
|
||||
device=None,
|
||||
text_input_ids=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if tokenizer is not None:
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
else:
|
||||
if text_input_ids is None:
|
||||
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = text_encoder.dtype
|
||||
@@ -923,26 +888,20 @@ def _encode_prompt_with_clip(
|
||||
tokenizer,
|
||||
prompt: str,
|
||||
device=None,
|
||||
text_input_ids=None,
|
||||
num_images_per_prompt: int = 1,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if tokenizer is not None:
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
else:
|
||||
if text_input_ids is None:
|
||||
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=77,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
@@ -964,7 +923,6 @@ def encode_prompt(
|
||||
max_sequence_length,
|
||||
device=None,
|
||||
num_images_per_prompt: int = 1,
|
||||
text_input_ids_list=None,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
@@ -973,14 +931,13 @@ def encode_prompt(
|
||||
|
||||
clip_prompt_embeds_list = []
|
||||
clip_pooled_prompt_embeds_list = []
|
||||
for i, (tokenizer, text_encoder) in enumerate(zip(clip_tokenizers, clip_text_encoders)):
|
||||
for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders):
|
||||
prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
device=device if device is not None else text_encoder.device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
text_input_ids=text_input_ids_list[i] if text_input_ids_list else None,
|
||||
)
|
||||
clip_prompt_embeds_list.append(prompt_embeds)
|
||||
clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds)
|
||||
@@ -994,7 +951,6 @@ def encode_prompt(
|
||||
max_sequence_length,
|
||||
prompt=prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
text_input_ids=text_input_ids_list[-1] if text_input_ids_list else None,
|
||||
device=device if device is not None else text_encoders[-1].device,
|
||||
)
|
||||
|
||||
@@ -1189,9 +1145,6 @@ def main(args):
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
transformer.enable_gradient_checkpointing()
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.gradient_checkpointing_enable()
|
||||
text_encoder_two.gradient_checkpointing_enable()
|
||||
|
||||
# now we will add new LoRA weights to the attention layers
|
||||
transformer_lora_config = LoraConfig(
|
||||
@@ -1202,16 +1155,6 @@ def main(args):
|
||||
)
|
||||
transformer.add_adapter(transformer_lora_config)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_lora_config = LoraConfig(
|
||||
r=args.rank,
|
||||
lora_alpha=args.rank,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
|
||||
)
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
@@ -1221,16 +1164,10 @@ def main(args):
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
text_encoder_one_lora_layers_to_save = None
|
||||
text_encoder_two_lora_layers_to_save = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1238,26 +1175,17 @@ def main(args):
|
||||
weights.pop()
|
||||
|
||||
StableDiffusion3Pipeline.save_lora_weights(
|
||||
output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers_to_save,
|
||||
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
|
||||
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
|
||||
output_dir, transformer_lora_layers=transformer_lora_layers_to_save
|
||||
)
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
text_encoder_one_ = None
|
||||
text_encoder_two_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_one))):
|
||||
text_encoder_one_ = model
|
||||
elif isinstance(model, type(unwrap_model(text_encoder_two))):
|
||||
text_encoder_two_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
@@ -1276,21 +1204,12 @@ def main(args):
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
# Do we need to call `scale_lora_layers()` here?
|
||||
_set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
|
||||
|
||||
_set_state_dict_into_text_encoder(
|
||||
lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
|
||||
)
|
||||
|
||||
# Make sure the trainable params are in float32. This is again needed since the base models
|
||||
# are in `weight_dtype`. More details:
|
||||
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [transformer_]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one_, text_encoder_two_])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models)
|
||||
|
||||
@@ -1310,37 +1229,14 @@ def main(args):
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [transformer]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
cast_training_params(models, dtype=torch.float32)
|
||||
|
||||
transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
|
||||
if args.train_text_encoder:
|
||||
text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
|
||||
text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
|
||||
|
||||
# Optimization parameters
|
||||
transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
|
||||
if args.train_text_encoder:
|
||||
# different learning rate for text encoder and unet
|
||||
text_lora_parameters_one_with_lr = {
|
||||
"params": text_lora_parameters_one,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
text_lora_parameters_two_with_lr = {
|
||||
"params": text_lora_parameters_two,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
text_lora_parameters_one_with_lr,
|
||||
text_lora_parameters_two_with_lr,
|
||||
]
|
||||
else:
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
@@ -1421,33 +1317,31 @@ def main(args):
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
if not args.train_text_encoder:
|
||||
tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
|
||||
text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
|
||||
tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three]
|
||||
text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three]
|
||||
|
||||
def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders, tokenizers, prompt, args.max_sequence_length
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
def compute_text_embeddings(prompt, text_encoders, tokenizers):
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders, tokenizers, prompt, args.max_sequence_length
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(accelerator.device)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.instance_prompt, text_encoders, tokenizers
|
||||
)
|
||||
|
||||
# Handle class prompt for prior-preservation.
|
||||
if args.with_prior_preservation:
|
||||
if not args.train_text_encoder:
|
||||
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.class_prompt, text_encoders, tokenizers
|
||||
)
|
||||
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.class_prompt, text_encoders, tokenizers
|
||||
)
|
||||
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and train_dataset.custom_instance_prompts:
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
del tokenizers, text_encoders
|
||||
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
@@ -1460,13 +1354,12 @@ def main(args):
|
||||
# have to pass them to the dataloader.
|
||||
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
if not args.train_text_encoder:
|
||||
prompt_embeds = instance_prompt_hidden_states
|
||||
pooled_prompt_embeds = instance_pooled_prompt_embeds
|
||||
if args.with_prior_preservation:
|
||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
|
||||
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
||||
prompt_embeds = instance_prompt_hidden_states
|
||||
pooled_prompt_embeds = instance_pooled_prompt_embeds
|
||||
if args.with_prior_preservation:
|
||||
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
|
||||
# if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
|
||||
# batch prompts on all training steps
|
||||
else:
|
||||
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
|
||||
@@ -1497,25 +1390,9 @@ def main(args):
|
||||
)
|
||||
|
||||
# Prepare everything with our `accelerator`.
|
||||
# Prepare everything with our `accelerator`.
|
||||
if args.train_text_encoder:
|
||||
(
|
||||
transformer,
|
||||
text_encoder_one,
|
||||
text_encoder_two,
|
||||
optimizer,
|
||||
train_dataloader,
|
||||
lr_scheduler,
|
||||
) = accelerator.prepare(
|
||||
transformer, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
assert text_encoder_one is not None
|
||||
assert text_encoder_two is not None
|
||||
assert text_encoder_three is not None
|
||||
else:
|
||||
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
transformer, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -1593,13 +1470,6 @@ def main(args):
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
transformer.train()
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one.train()
|
||||
text_encoder_two.train()
|
||||
|
||||
# set top parameter requires_grad = True for gradient checkpointing works
|
||||
accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
|
||||
accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
models_to_accumulate = [transformer]
|
||||
@@ -1609,30 +1479,7 @@ def main(args):
|
||||
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if not args.train_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(
|
||||
prompts, text_encoders, tokenizers
|
||||
)
|
||||
else:
|
||||
tokens_one = tokenize_prompt(tokenizer_one, prompts)
|
||||
tokens_two = tokenize_prompt(tokenizer_two, prompts)
|
||||
tokens_three = tokenize_prompt(tokenizer_three, prompts)
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
|
||||
tokenizers=[None, None, None],
|
||||
prompt=prompts,
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
|
||||
)
|
||||
else:
|
||||
if args.train_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
|
||||
tokenizers=[None, None, tokenizer_three],
|
||||
prompt=args.instance_prompt,
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
text_input_ids_list=[tokens_one, tokens_two, tokens_three],
|
||||
)
|
||||
prompt_embeds, pooled_prompt_embeds = compute_text_embeddings(prompts, text_encoders, tokenizers)
|
||||
|
||||
# Convert images to latent space
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
@@ -1706,13 +1553,7 @@ def main(args):
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(
|
||||
transformer_lora_parameters, text_lora_parameters_one, text_lora_parameters_two
|
||||
)
|
||||
if args.train_text_encoder
|
||||
else transformer_lora_parameters
|
||||
)
|
||||
params_to_clip = transformer_lora_parameters
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
@@ -1759,11 +1600,10 @@ def main(args):
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
|
||||
if not args.train_text_encoder:
|
||||
# create pipeline
|
||||
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
|
||||
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
|
||||
)
|
||||
# create pipeline
|
||||
text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
|
||||
text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
|
||||
)
|
||||
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
@@ -1783,9 +1623,7 @@ def main(args):
|
||||
pipeline_args=pipeline_args,
|
||||
epoch=epoch,
|
||||
)
|
||||
if not args.train_text_encoder:
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
|
||||
del text_encoder_one, text_encoder_two, text_encoder_three
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
@@ -1796,20 +1634,15 @@ def main(args):
|
||||
transformer = transformer.to(torch.float32)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_one = unwrap_model(text_encoder_one)
|
||||
text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
|
||||
text_encoder_two = unwrap_model(text_encoder_two)
|
||||
text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
|
||||
else:
|
||||
text_encoder_lora_layers = None
|
||||
text_encoder_2_lora_layers = None
|
||||
|
||||
StableDiffusion3Pipeline.save_lora_weights(
|
||||
save_directory=args.output_dir,
|
||||
transformer_lora_layers=transformer_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
save_directory=args.output_dir, transformer_lora_layers=transformer_lora_layers
|
||||
)
|
||||
|
||||
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
|
||||
# Final inference
|
||||
@@ -1843,7 +1676,6 @@ def main(args):
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
instance_prompt=args.instance_prompt,
|
||||
validation_prompt=args.validation_prompt,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
repo_folder=args.output_dir,
|
||||
)
|
||||
upload_folder(
|
||||
|
||||
@@ -95,22 +95,17 @@ def save_model_card(
|
||||
|
||||
These are {repo_id} DreamBooth weights for {base_model}.
|
||||
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md).
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
|
||||
|
||||
Was the text encoder fine-tuned? {train_text_encoder}.
|
||||
Text encoder was fine-tuned: {train_text_encoder}.
|
||||
|
||||
## Trigger words
|
||||
|
||||
You should use `{instance_prompt}` to trigger the image generation.
|
||||
You should use {instance_prompt} to trigger the image generation.
|
||||
|
||||
## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
|
||||
## Download model
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
import torch
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained('{repo_id}', torch_dtype=torch.float16).to('cuda')
|
||||
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
|
||||
```
|
||||
[Download]({repo_id}/tree/main) them in the Files & versions tab.
|
||||
|
||||
## License
|
||||
|
||||
|
||||
@@ -1195,7 +1195,7 @@ def main(args):
|
||||
|
||||
# Resolve the c parameter for the Pseudo-Huber loss
|
||||
if args.huber_c is None:
|
||||
args.huber_c = 0.00054 * args.resolution * math.sqrt(unwrap_model(unet).config.in_channels)
|
||||
args.huber_c = 0.00054 * args.resolution * math.sqrt(unet.config.in_channels)
|
||||
|
||||
# Get current number of discretization steps N according to our discretization curriculum
|
||||
current_discretization_steps = get_discretization_steps(
|
||||
|
||||
@@ -30,7 +30,7 @@ accelerate launch finetune_instruct_pix2pix.py \
|
||||
## Inference
|
||||
After training the model and the lora weight of the model is stored in the ```$OUTPUT_DIR```.
|
||||
|
||||
```py
|
||||
```bash
|
||||
# load the base model pipeline
|
||||
pipe_lora = StableDiffusionInstructPix2PixPipeline.from_pretrained("timbrooks/instruct-pix2pix")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ This aims to provide diffusers examples with Intel optimizations such as Bfloat1
|
||||
|
||||
## Accelerating the fine-tuning for textual inversion
|
||||
|
||||
We accelerate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.
|
||||
We accelereate the fine-tuning for textual inversion with Intel Extension for PyTorch. The [examples](textual_inversion) enable both single node and multi-node distributed training with Bfloat16 support on Intel Xeon Scalable Processor.
|
||||
|
||||
## Accelerating the inference for Stable Diffusion using Bfloat16
|
||||
|
||||
|
||||
@@ -323,7 +323,7 @@ accelerate launch train_dreambooth.py \
|
||||
|
||||
### Using DreamBooth for other pipelines than Stable Diffusion
|
||||
|
||||
Altdiffusion also supports dreambooth now, the running command is basically the same as above, all you need to do is replace the `MODEL_NAME` like this:
|
||||
Altdiffusion also support dreambooth now, the runing comman is basically the same as above, all you need to do is replace the `MODEL_NAME` like this:
|
||||
One can now simply change the `pretrained_model_name_or_path` to another architecture such as [`AltDiffusion`](https://huggingface.co/docs/diffusers/api/pipelines/alt_diffusion).
|
||||
|
||||
```
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
# Running Stable Diffusion 3 DreamBooth LoRA training under 16GB
|
||||
|
||||
This is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA training for [Stable Diffusion 3 (SD3)](ttps://huggingface.co/papers/2403.03206) under 16GB GPU VRAM. This means you can successfully try out this project using a [free-tier Colab Notebook](https://colab.research.google.com/github/huggingface/diffusers/blob/main/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb) instance. 🤗
|
||||
|
||||
> [!NOTE]
|
||||
> SD3 is gated, so you need to make sure you agree to [share your contact info](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) to access the model before using it with Diffusers. Once you have access, you need to log in so your system knows you’re authorized. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
For setup, inference code, and details on how to run the code, please follow the Colab Notebook provided above.
|
||||
|
||||
## How
|
||||
|
||||
We make use of several techniques to make this possible:
|
||||
|
||||
* Compute the embeddings from the instance prompt and serialize them for later reuse. This is implemented in the [`compute_embeddings.py`](./compute_embeddings.py) script. We use an 8bit (as introduced in [`LLM.int8()`](https://arxiv.org/abs/2208.07339)) T5 to reduce memory requirements to ~10.5GB.
|
||||
* In the `train_dreambooth_sd3_lora_miniature.py` script, we make use of:
|
||||
* 8bit Adam for optimization through the `bitsandbytes` library.
|
||||
* Gradient checkpointing and gradient accumulation.
|
||||
* FP16 precision.
|
||||
* Flash attention through `F.scaled_dot_product_attention()`.
|
||||
|
||||
Computing the text embeddings is arguably the most memory-intensive part in the pipeline as SD3 employs three text encoders. If we run them in FP32, it will take about 20GB of VRAM. With FP16, we are down to 12GB.
|
||||
|
||||
|
||||
## Gotchas
|
||||
|
||||
This project is educational. It exists to showcase the possibility of fine-tuning a big diffusion system on consumer GPUs. But additional components might have to be added to obtain state-of-the-art performance. Below are some commonly known gotchas that users should be aware of:
|
||||
|
||||
* Training of text encoders is purposefully disabled.
|
||||
* Techniques such as prior-preservation is unsupported.
|
||||
* Custom instance captions for instance images are unsupported, but this should be relatively easy to integrate.
|
||||
|
||||
Hopefully, this project gives you a template to extend it further to suit your needs.
|
||||
@@ -1,123 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import hashlib
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
|
||||
PROMPT = "a photo of sks dog"
|
||||
MAX_SEQ_LENGTH = 77
|
||||
LOCAL_DATA_DIR = "dog"
|
||||
OUTPUT_PATH = "sample_embeddings.parquet"
|
||||
|
||||
|
||||
def bytes_to_giga_bytes(bytes):
|
||||
return bytes / 1024 / 1024 / 1024
|
||||
|
||||
|
||||
def generate_image_hash(image_path):
|
||||
with open(image_path, "rb") as f:
|
||||
img_data = f.read()
|
||||
return hashlib.sha256(img_data).hexdigest()
|
||||
|
||||
|
||||
def load_sd3_pipeline():
|
||||
id = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_3", load_in_8bit=True, device_map="auto")
|
||||
pipeline = StableDiffusion3Pipeline.from_pretrained(
|
||||
id, text_encoder_3=text_encoder, transformer=None, vae=None, device_map="balanced"
|
||||
)
|
||||
return pipeline
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_embeddings(pipeline, prompt, max_sequence_length):
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None, max_sequence_length=max_sequence_length)
|
||||
|
||||
print(
|
||||
f"{prompt_embeds.shape=}, {negative_prompt_embeds.shape=}, {pooled_prompt_embeds.shape=}, {negative_pooled_prompt_embeds.shape}"
|
||||
)
|
||||
|
||||
max_memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated())
|
||||
print(f"Max memory allocated: {max_memory:.3f} GB")
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
|
||||
|
||||
def run(args):
|
||||
pipeline = load_sd3_pipeline()
|
||||
prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = compute_embeddings(
|
||||
pipeline, args.prompt, args.max_sequence_length
|
||||
)
|
||||
|
||||
# Assumes that the images within `args.local_image_dir` have a JPEG extension. Change
|
||||
# as needed.
|
||||
image_paths = glob.glob(f"{args.local_data_dir}/*.jpeg")
|
||||
data = []
|
||||
for image_path in image_paths:
|
||||
img_hash = generate_image_hash(image_path)
|
||||
data.append(
|
||||
(img_hash, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)
|
||||
)
|
||||
|
||||
# Create a DataFrame
|
||||
embedding_cols = [
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
"pooled_prompt_embeds",
|
||||
"negative_pooled_prompt_embeds",
|
||||
]
|
||||
df = pd.DataFrame(
|
||||
data,
|
||||
columns=["image_hash"] + embedding_cols,
|
||||
)
|
||||
|
||||
# Convert embedding lists to arrays (for proper storage in parquet)
|
||||
for col in embedding_cols:
|
||||
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
|
||||
|
||||
# Save the dataframe to a parquet file
|
||||
df.to_parquet(args.output_path)
|
||||
print(f"Data successfully serialized to {args.output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--prompt", type=str, default=PROMPT, help="The instance prompt.")
|
||||
parser.add_argument(
|
||||
"--max_sequence_length",
|
||||
type=int,
|
||||
default=MAX_SEQ_LENGTH,
|
||||
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_data_dir", type=str, default=LOCAL_DATA_DIR, help="Path to the directory containing instance images."
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
|
||||
args = parser.parse_args()
|
||||
|
||||
run(args)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,11 +0,0 @@
|
||||
# VAE
|
||||
|
||||
`vae_roundtrip.py` Demonstrates the use of a VAE by roundtripping an image through the encoder and decoder. Original and reconstructed images are displayed side by side.
|
||||
|
||||
```
|
||||
cd examples/research_projects/vae
|
||||
python vae_roundtrip.py \
|
||||
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
|
||||
--subfolder="vae" \
|
||||
--input_image="/path/to/your/input.png"
|
||||
```
|
||||
@@ -1,282 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
import argparse
|
||||
import typing
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms # type: ignore
|
||||
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.models.autoencoders.autoencoder_kl import (
|
||||
AutoencoderKL,
|
||||
AutoencoderKLOutput,
|
||||
)
|
||||
from diffusers.models.autoencoders.autoencoder_tiny import (
|
||||
AutoencoderTiny,
|
||||
AutoencoderTinyOutput,
|
||||
)
|
||||
from diffusers.models.autoencoders.vae import DecoderOutput
|
||||
|
||||
|
||||
SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]
|
||||
|
||||
|
||||
def load_vae_model(
|
||||
*,
|
||||
device: torch.device,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
variant: Optional[str],
|
||||
# NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE
|
||||
subfolder: Optional[str],
|
||||
use_tiny_nn: bool,
|
||||
) -> SupportedAutoencoder:
|
||||
if use_tiny_nn:
|
||||
# NOTE: These scaling factors don't have to be the same as each other.
|
||||
down_scale = 2
|
||||
up_scale = 2
|
||||
vae = AutoencoderTiny.from_pretrained( # type: ignore
|
||||
model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
variant=variant,
|
||||
downscaling_scaling_factor=down_scale,
|
||||
upsampling_scaling_factor=up_scale,
|
||||
)
|
||||
assert isinstance(vae, AutoencoderTiny)
|
||||
else:
|
||||
vae = AutoencoderKL.from_pretrained( # type: ignore
|
||||
model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
variant=variant,
|
||||
)
|
||||
assert isinstance(vae, AutoencoderKL)
|
||||
vae = vae.to(device)
|
||||
vae.eval() # Set the model to inference mode
|
||||
return vae
|
||||
|
||||
|
||||
def pil_to_nhwc(
|
||||
*,
|
||||
device: torch.device,
|
||||
image: Image.Image,
|
||||
) -> torch.Tensor:
|
||||
assert image.mode == "RGB"
|
||||
transform = transforms.ToTensor()
|
||||
nhwc = transform(image).unsqueeze(0).to(device) # type: ignore
|
||||
assert isinstance(nhwc, torch.Tensor)
|
||||
return nhwc
|
||||
|
||||
|
||||
def nhwc_to_pil(
|
||||
*,
|
||||
nhwc: torch.Tensor,
|
||||
) -> Image.Image:
|
||||
assert nhwc.shape[0] == 1
|
||||
hwc = nhwc.squeeze(0).cpu()
|
||||
return transforms.ToPILImage()(hwc) # type: ignore
|
||||
|
||||
|
||||
def concatenate_images(
|
||||
*,
|
||||
left: Image.Image,
|
||||
right: Image.Image,
|
||||
vertical: bool = False,
|
||||
) -> Image.Image:
|
||||
width1, height1 = left.size
|
||||
width2, height2 = right.size
|
||||
if vertical:
|
||||
total_height = height1 + height2
|
||||
max_width = max(width1, width2)
|
||||
new_image = Image.new("RGB", (max_width, total_height))
|
||||
new_image.paste(left, (0, 0))
|
||||
new_image.paste(right, (0, height1))
|
||||
else:
|
||||
total_width = width1 + width2
|
||||
max_height = max(height1, height2)
|
||||
new_image = Image.new("RGB", (total_width, max_height))
|
||||
new_image.paste(left, (0, 0))
|
||||
new_image.paste(right, (width1, 0))
|
||||
return new_image
|
||||
|
||||
|
||||
def to_latent(
|
||||
*,
|
||||
rgb_nchw: torch.Tensor,
|
||||
vae: SupportedAutoencoder,
|
||||
) -> torch.Tensor:
|
||||
rgb_nchw = VaeImageProcessor.normalize(rgb_nchw) # type: ignore
|
||||
encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw))
|
||||
if isinstance(encoding_nchw, AutoencoderKLOutput):
|
||||
latent = encoding_nchw.latent_dist.sample() # type: ignore
|
||||
assert isinstance(latent, torch.Tensor)
|
||||
elif isinstance(encoding_nchw, AutoencoderTinyOutput):
|
||||
latent = encoding_nchw.latents
|
||||
do_internal_vae_scaling = False # Is this needed?
|
||||
if do_internal_vae_scaling:
|
||||
latent = vae.scale_latents(latent).mul(255).round().byte() # type: ignore
|
||||
latent = vae.unscale_latents(latent / 255.0) # type: ignore
|
||||
assert isinstance(latent, torch.Tensor)
|
||||
else:
|
||||
assert False, f"Unknown encoding type: {type(encoding_nchw)}"
|
||||
return latent
|
||||
|
||||
|
||||
def from_latent(
|
||||
*,
|
||||
latent_nchw: torch.Tensor,
|
||||
vae: SupportedAutoencoder,
|
||||
) -> torch.Tensor:
|
||||
decoding_nchw = vae.decode(latent_nchw) # type: ignore
|
||||
assert isinstance(decoding_nchw, DecoderOutput)
|
||||
rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample) # type: ignore
|
||||
assert isinstance(rgb_nchw, torch.Tensor)
|
||||
return rgb_nchw
|
||||
|
||||
|
||||
def main_kwargs(
|
||||
*,
|
||||
device: torch.device,
|
||||
input_image_path: str,
|
||||
pretrained_model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
variant: Optional[str],
|
||||
subfolder: Optional[str],
|
||||
use_tiny_nn: bool,
|
||||
) -> None:
|
||||
vae = load_vae_model(
|
||||
device=device,
|
||||
model_name_or_path=pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
variant=variant,
|
||||
subfolder=subfolder,
|
||||
use_tiny_nn=use_tiny_nn,
|
||||
)
|
||||
original_pil = Image.open(input_image_path).convert("RGB")
|
||||
original_image = pil_to_nhwc(
|
||||
device=device,
|
||||
image=original_pil,
|
||||
)
|
||||
print(f"Original image shape: {original_image.shape}")
|
||||
reconstructed_image: Optional[torch.Tensor] = None
|
||||
|
||||
with torch.no_grad():
|
||||
latent_image = to_latent(rgb_nchw=original_image, vae=vae)
|
||||
print(f"Latent shape: {latent_image.shape}")
|
||||
reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae)
|
||||
reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image)
|
||||
combined_image = concatenate_images(
|
||||
left=original_pil,
|
||||
right=reconstructed_pil,
|
||||
vertical=False,
|
||||
)
|
||||
combined_image.show("Original | Reconstruction")
|
||||
print(f"Reconstructed image shape: {reconstructed_image.shape}")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Inference with VAE")
|
||||
parser.add_argument(
|
||||
"--input_image",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the input image for inference.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pretrained VAE model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model version.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Model file variant, e.g., 'fp16'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--subfolder",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Subfolder in the model file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuda",
|
||||
action="store_true",
|
||||
help="Use CUDA if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_tiny_nn",
|
||||
action="store_true",
|
||||
help="Use tiny neural network.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# EXAMPLE USAGE:
|
||||
#
|
||||
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
|
||||
#
|
||||
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
|
||||
#
|
||||
def main_cli() -> None:
|
||||
args = parse_args()
|
||||
|
||||
input_image_path = args.input_image
|
||||
assert isinstance(input_image_path, str)
|
||||
|
||||
pretrained_model_name_or_path = args.pretrained_model_name_or_path
|
||||
assert isinstance(pretrained_model_name_or_path, str)
|
||||
|
||||
revision = args.revision
|
||||
assert isinstance(revision, (str, type(None)))
|
||||
|
||||
variant = args.variant
|
||||
assert isinstance(variant, (str, type(None)))
|
||||
|
||||
subfolder = args.subfolder
|
||||
assert isinstance(subfolder, (str, type(None)))
|
||||
|
||||
use_cuda = args.use_cuda
|
||||
assert isinstance(use_cuda, bool)
|
||||
|
||||
use_tiny_nn = args.use_tiny_nn
|
||||
assert isinstance(use_tiny_nn, bool)
|
||||
|
||||
device = torch.device("cuda" if use_cuda else "cpu")
|
||||
|
||||
main_kwargs(
|
||||
device=device,
|
||||
input_image_path=input_image_path,
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
variant=variant,
|
||||
subfolder=subfolder,
|
||||
use_tiny_nn=use_tiny_nn,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_cli()
|
||||
@@ -45,7 +45,7 @@ accelerate launch train_vqgan.py \
|
||||
```
|
||||
|
||||
An example training run is [here](https://wandb.ai/sayakpaul/vqgan-training/runs/0m5kzdfp) by @sayakpaul and a lower scale one [here](https://wandb.ai/dsbuddy27/vqgan-training/runs/eqd6xi4n?nw=nwuserisamu). The validation images can be obtained from [here](https://huggingface.co/datasets/diffusers/docs-images/tree/main/vqgan_validation_images).
|
||||
The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is increasing the image resolution. However, other ways include, but not limited to, lowering compression by downsampling fewer times or increasing the vocabulary size which at most can be around 16384. How to do this is shown below.
|
||||
The simplest way to improve the quality of a VQGAN model is to maximize the amount of information present in the bottleneck. The easiest way to do this is increasing the image resolution. However, other ways include, but not limited to, lowering compression by downsampling fewer times or increasing the vocaburary size which at most can be around 16384. How to do this is shown below.
|
||||
|
||||
# Modifying the architecture
|
||||
|
||||
@@ -118,10 +118,10 @@ To lower the amount of layers in a VQGan, you can remove layers by modifying the
|
||||
"vq_embed_dim": 4
|
||||
}
|
||||
```
|
||||
For increasing the size of the vocabularies you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that.
|
||||
For increasing the size of the vocaburaries you can increase num_vq_embeddings. However, [some research](https://magvit.cs.cmu.edu/v2/) shows that the representation of VQGANs start degrading after 2^14~16384 vq embeddings so it's not recommended to go past that.
|
||||
|
||||
## Extra training tips/ideas
|
||||
During logging take care to make sure data_time is low. data_time is the amount spent loading the data and where the GPU is not active. So essentially, it's the time wasted. The easiest way to lower data time is to increase the --dataloader_num_workers to a higher number like 4. Due to a bug in Pytorch, this only works on linux based systems. For more details check [here](https://github.com/huggingface/diffusers/issues/7646)
|
||||
Secondly, training should seem to be done when both the discriminator and the generator loss converges.
|
||||
Thirdly, another low hanging fruit is just using ema using the --use_ema parameter. This tends to make the output images smoother. This has a con where you have to lower your batch size by 1 but it may be worth it.
|
||||
Another more experimental low hanging fruit is changing from the vgg19 to different models for the lpips loss using the --timm_model_backend. If you do this, I recommend also changing the timm_model_layers parameter to the layer in your model which you think is best for representation. However, be careful with the feature map norms since this can easily overdominate the loss.
|
||||
Another more experimental low hanging fruit is changing from the vgg19 to different models for the lpips loss using the --timm_model_backend. If you do this, I recommend also changing the timm_model_layers parameter to the layer in your model which you think is best for representation. However, becareful with the feature map norms since this can easily overdominate the loss.
|
||||
@@ -1,131 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowTransformer2DModel
|
||||
|
||||
|
||||
def load_original_state_dict(args):
|
||||
model_pt = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename="aura_diffusion_pytorch_model.bin")
|
||||
state_dict = torch.load(model_pt, map_location="cpu")
|
||||
return state_dict
|
||||
|
||||
|
||||
def calculate_layers(state_dict_keys, key_prefix):
|
||||
dit_layers = set()
|
||||
for k in state_dict_keys:
|
||||
if key_prefix in k:
|
||||
dit_layers.add(int(k.split(".")[2]))
|
||||
print(f"{key_prefix}: {len(dit_layers)}")
|
||||
return len(dit_layers)
|
||||
|
||||
|
||||
# similar to SD3 but only for the last norm layer
|
||||
def swap_scale_shift(weight, dim):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_transformer(state_dict):
|
||||
converted_state_dict = {}
|
||||
state_dict_keys = list(state_dict.keys())
|
||||
|
||||
converted_state_dict["register_tokens"] = state_dict.pop("model.register_tokens")
|
||||
converted_state_dict["pos_embed.pos_embed"] = state_dict.pop("model.positional_encoding")
|
||||
converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("model.init_x_linear.weight")
|
||||
converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("model.init_x_linear.bias")
|
||||
|
||||
converted_state_dict["time_step_proj.linear_1.weight"] = state_dict.pop("model.t_embedder.mlp.0.weight")
|
||||
converted_state_dict["time_step_proj.linear_1.bias"] = state_dict.pop("model.t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_step_proj.linear_2.weight"] = state_dict.pop("model.t_embedder.mlp.2.weight")
|
||||
converted_state_dict["time_step_proj.linear_2.bias"] = state_dict.pop("model.t_embedder.mlp.2.bias")
|
||||
|
||||
converted_state_dict["context_embedder.weight"] = state_dict.pop("model.cond_seq_linear.weight")
|
||||
|
||||
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
|
||||
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
|
||||
|
||||
# MMDiT blocks 🎸.
|
||||
for i in range(mmdit_layers):
|
||||
# feed-forward
|
||||
path_mapping = {"mlpX": "ff", "mlpC": "ff_context"}
|
||||
weight_mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
||||
for orig_k, diffuser_k in path_mapping.items():
|
||||
for k, v in weight_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.{v}.weight"] = state_dict.pop(
|
||||
f"model.double_layers.{i}.{orig_k}.{k}.weight"
|
||||
)
|
||||
|
||||
# norms
|
||||
path_mapping = {"modX": "norm1", "modC": "norm1_context"}
|
||||
for orig_k, diffuser_k in path_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.{diffuser_k}.linear.weight"] = state_dict.pop(
|
||||
f"model.double_layers.{i}.{orig_k}.1.weight"
|
||||
)
|
||||
|
||||
# attns
|
||||
x_attn_mapping = {"w2q": "to_q", "w2k": "to_k", "w2v": "to_v", "w2o": "to_out.0"}
|
||||
context_attn_mapping = {"w1q": "add_q_proj", "w1k": "add_k_proj", "w1v": "add_v_proj", "w1o": "to_add_out"}
|
||||
for attn_mapping in [x_attn_mapping, context_attn_mapping]:
|
||||
for k, v in attn_mapping.items():
|
||||
converted_state_dict[f"joint_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
|
||||
f"model.double_layers.{i}.attn.{k}.weight"
|
||||
)
|
||||
|
||||
# Single-DiT blocks.
|
||||
for i in range(single_dit_layers):
|
||||
# feed-forward
|
||||
mapping = {"c_fc1": "linear_1", "c_fc2": "linear_2", "c_proj": "out_projection"}
|
||||
for k, v in mapping.items():
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.ff.{v}.weight"] = state_dict.pop(
|
||||
f"model.single_layers.{i}.mlp.{k}.weight"
|
||||
)
|
||||
|
||||
# norms
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.norm1.linear.weight"] = state_dict.pop(
|
||||
f"model.single_layers.{i}.modCX.1.weight"
|
||||
)
|
||||
|
||||
# attns
|
||||
x_attn_mapping = {"w1q": "to_q", "w1k": "to_k", "w1v": "to_v", "w1o": "to_out.0"}
|
||||
for k, v in x_attn_mapping.items():
|
||||
converted_state_dict[f"single_transformer_blocks.{i}.attn.{v}.weight"] = state_dict.pop(
|
||||
f"model.single_layers.{i}.attn.{k}.weight"
|
||||
)
|
||||
|
||||
# Final blocks.
|
||||
converted_state_dict["proj_out.weight"] = state_dict.pop("model.final_linear.weight")
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict.pop("model.modF.1.weight"), dim=None)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def populate_state_dict(args):
|
||||
original_state_dict = load_original_state_dict(args)
|
||||
state_dict_keys = list(original_state_dict.keys())
|
||||
mmdit_layers = calculate_layers(state_dict_keys, key_prefix="double_layers")
|
||||
single_dit_layers = calculate_layers(state_dict_keys, key_prefix="single_layers")
|
||||
|
||||
converted_state_dict = convert_transformer(original_state_dict)
|
||||
model_diffusers = AuraFlowTransformer2DModel(
|
||||
num_mmdit_layers=mmdit_layers, num_single_dit_layers=single_dit_layers
|
||||
)
|
||||
model_diffusers.load_state_dict(converted_state_dict, strict=True)
|
||||
|
||||
return model_diffusers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--original_state_dict_repo_id", default="AuraDiffusion/auradiffusion-v0.1a0", type=str)
|
||||
parser.add_argument("--dump_path", default="aura-flow", type=str)
|
||||
parser.add_argument("--hub_id", default=None, type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
model_diffusers = populate_state_dict(args)
|
||||
model_diffusers.save_pretrained(args.dump_path)
|
||||
if args.hub_id is not None:
|
||||
model_diffusers.push_to_hub(args.hub_id)
|
||||
@@ -1,241 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanDiT2DControlNetModel
|
||||
|
||||
|
||||
def main(args):
|
||||
state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
|
||||
|
||||
if args.load_key != "none":
|
||||
try:
|
||||
state_dict = state_dict[args.load_key]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"{args.load_key} not found in the checkpoint."
|
||||
"Please load from the following keys:{state_dict.keys()}"
|
||||
)
|
||||
device = "cuda"
|
||||
|
||||
model_config = HunyuanDiT2DControlNetModel.load_config(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
|
||||
)
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
print(model_config)
|
||||
|
||||
for key in state_dict:
|
||||
print("local:", key)
|
||||
|
||||
model = HunyuanDiT2DControlNetModel.from_config(model_config).to(device)
|
||||
|
||||
for key in model.state_dict():
|
||||
print("diffusers:", key)
|
||||
|
||||
num_layers = 19
|
||||
for i in range(num_layers):
|
||||
# attn1
|
||||
# Wkqv -> to_q, to_k, to_v
|
||||
q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
|
||||
state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
|
||||
state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
|
||||
state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
|
||||
state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
|
||||
state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
|
||||
state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
|
||||
state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
|
||||
|
||||
# q_norm, k_norm -> norm_q, norm_k
|
||||
state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
|
||||
|
||||
state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
|
||||
state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
|
||||
|
||||
# out_proj -> to_out
|
||||
state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
|
||||
|
||||
# attn2
|
||||
# kq_proj -> to_k, to_v
|
||||
k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
|
||||
state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
|
||||
state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
|
||||
state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
|
||||
state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
|
||||
state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
|
||||
|
||||
# q_proj -> to_q
|
||||
state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
|
||||
|
||||
# q_norm, k_norm -> norm_q, norm_k
|
||||
state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
|
||||
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
|
||||
state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
|
||||
|
||||
# out_proj -> to_out
|
||||
state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
|
||||
|
||||
# switch norm 2 and norm 3
|
||||
norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
|
||||
norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
|
||||
state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
|
||||
state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
|
||||
state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
|
||||
state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
|
||||
|
||||
# norm1 -> norm1.norm
|
||||
# default_modulation.1 -> norm1.linear
|
||||
state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
|
||||
state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
|
||||
state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
|
||||
state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
|
||||
state_dict.pop(f"blocks.{i}.norm1.weight")
|
||||
state_dict.pop(f"blocks.{i}.norm1.bias")
|
||||
state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
|
||||
state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
|
||||
|
||||
# mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
|
||||
state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
|
||||
state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
|
||||
state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
|
||||
state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
|
||||
|
||||
# after_proj_list -> controlnet_blocks
|
||||
state_dict[f"controlnet_blocks.{i}.weight"] = state_dict[f"after_proj_list.{i}.weight"]
|
||||
state_dict[f"controlnet_blocks.{i}.bias"] = state_dict[f"after_proj_list.{i}.bias"]
|
||||
state_dict.pop(f"after_proj_list.{i}.weight")
|
||||
state_dict.pop(f"after_proj_list.{i}.bias")
|
||||
|
||||
# before_proj -> input_block
|
||||
state_dict["input_block.weight"] = state_dict["before_proj.weight"]
|
||||
state_dict["input_block.bias"] = state_dict["before_proj.bias"]
|
||||
state_dict.pop("before_proj.weight")
|
||||
state_dict.pop("before_proj.bias")
|
||||
|
||||
# pooler -> time_extra_emb
|
||||
state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
|
||||
state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
|
||||
state_dict.pop("pooler.k_proj.weight")
|
||||
state_dict.pop("pooler.k_proj.bias")
|
||||
state_dict.pop("pooler.q_proj.weight")
|
||||
state_dict.pop("pooler.q_proj.bias")
|
||||
state_dict.pop("pooler.v_proj.weight")
|
||||
state_dict.pop("pooler.v_proj.bias")
|
||||
state_dict.pop("pooler.c_proj.weight")
|
||||
state_dict.pop("pooler.c_proj.bias")
|
||||
state_dict.pop("pooler.positional_embedding")
|
||||
|
||||
# t_embedder -> time_embedding (`TimestepEmbedding`)
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
|
||||
|
||||
state_dict.pop("t_embedder.mlp.0.bias")
|
||||
state_dict.pop("t_embedder.mlp.0.weight")
|
||||
state_dict.pop("t_embedder.mlp.2.bias")
|
||||
state_dict.pop("t_embedder.mlp.2.weight")
|
||||
|
||||
# x_embedder -> pos_embd (`PatchEmbed`)
|
||||
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
|
||||
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
|
||||
state_dict.pop("x_embedder.proj.weight")
|
||||
state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# mlp_t5 -> text_embedder
|
||||
state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
|
||||
state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
|
||||
state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
|
||||
state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
|
||||
state_dict.pop("mlp_t5.0.bias")
|
||||
state_dict.pop("mlp_t5.0.weight")
|
||||
state_dict.pop("mlp_t5.2.bias")
|
||||
state_dict.pop("mlp_t5.2.weight")
|
||||
|
||||
# extra_embedder -> extra_embedder
|
||||
state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
|
||||
state_dict.pop("extra_embedder.0.bias")
|
||||
state_dict.pop("extra_embedder.0.weight")
|
||||
state_dict.pop("extra_embedder.2.bias")
|
||||
state_dict.pop("extra_embedder.2.weight")
|
||||
|
||||
# style_embedder
|
||||
if model_config["use_style_cond_and_image_meta_size"]:
|
||||
print(state_dict["style_embedder.weight"])
|
||||
print(state_dict["style_embedder.weight"].shape)
|
||||
state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
|
||||
state_dict.pop("style_embedder.weight")
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if args.save:
|
||||
model.save_pretrained(args.output_checkpoint_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the output converted diffusers pipeline.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_style_cond_and_image_meta_size",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="version <= v1.1: True; version >= v1.2: False",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,267 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HunyuanDiT2DModel
|
||||
|
||||
|
||||
def main(args):
|
||||
state_dict = torch.load(args.pt_checkpoint_path, map_location="cpu")
|
||||
|
||||
if args.load_key != "none":
|
||||
try:
|
||||
state_dict = state_dict[args.load_key]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"{args.load_key} not found in the checkpoint."
|
||||
f"Please load from the following keys:{state_dict.keys()}"
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
|
||||
# input_size -> sample_size, text_dim -> cross_attention_dim
|
||||
for key in state_dict:
|
||||
print("local:", key)
|
||||
|
||||
model = HunyuanDiT2DModel.from_config(model_config).to(device)
|
||||
|
||||
for key in model.state_dict():
|
||||
print("diffusers:", key)
|
||||
|
||||
num_layers = 40
|
||||
for i in range(num_layers):
|
||||
# attn1
|
||||
# Wkqv -> to_q, to_k, to_v
|
||||
q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0)
|
||||
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0)
|
||||
state_dict[f"blocks.{i}.attn1.to_q.weight"] = q
|
||||
state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias
|
||||
state_dict[f"blocks.{i}.attn1.to_k.weight"] = k
|
||||
state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias
|
||||
state_dict[f"blocks.{i}.attn1.to_v.weight"] = v
|
||||
state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias
|
||||
state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias")
|
||||
|
||||
# q_norm, k_norm -> norm_q, norm_k
|
||||
state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"]
|
||||
|
||||
state_dict.pop(f"blocks.{i}.attn1.q_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.q_norm.bias")
|
||||
state_dict.pop(f"blocks.{i}.attn1.k_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.k_norm.bias")
|
||||
|
||||
# out_proj -> to_out
|
||||
state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn1.out_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn1.out_proj.bias")
|
||||
|
||||
# attn2
|
||||
# kq_proj -> to_k, to_v
|
||||
k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0)
|
||||
k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0)
|
||||
state_dict[f"blocks.{i}.attn2.to_k.weight"] = k
|
||||
state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias
|
||||
state_dict[f"blocks.{i}.attn2.to_v.weight"] = v
|
||||
state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias
|
||||
state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias")
|
||||
|
||||
# q_proj -> to_q
|
||||
state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_proj.bias")
|
||||
|
||||
# q_norm, k_norm -> norm_q, norm_k
|
||||
state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_q.bias"] = state_dict[f"blocks.{i}.attn2.q_norm.bias"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"]
|
||||
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.q_norm.bias")
|
||||
state_dict.pop(f"blocks.{i}.attn2.k_norm.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.k_norm.bias")
|
||||
|
||||
# out_proj -> to_out
|
||||
state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"]
|
||||
state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"]
|
||||
state_dict.pop(f"blocks.{i}.attn2.out_proj.weight")
|
||||
state_dict.pop(f"blocks.{i}.attn2.out_proj.bias")
|
||||
|
||||
# switch norm 2 and norm 3
|
||||
norm2_weight = state_dict[f"blocks.{i}.norm2.weight"]
|
||||
norm2_bias = state_dict[f"blocks.{i}.norm2.bias"]
|
||||
state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"]
|
||||
state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"]
|
||||
state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight
|
||||
state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias
|
||||
|
||||
# norm1 -> norm1.norm
|
||||
# default_modulation.1 -> norm1.linear
|
||||
state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"]
|
||||
state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"]
|
||||
state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"]
|
||||
state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"]
|
||||
state_dict.pop(f"blocks.{i}.norm1.weight")
|
||||
state_dict.pop(f"blocks.{i}.norm1.bias")
|
||||
state_dict.pop(f"blocks.{i}.default_modulation.1.weight")
|
||||
state_dict.pop(f"blocks.{i}.default_modulation.1.bias")
|
||||
|
||||
# mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2
|
||||
state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"]
|
||||
state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"]
|
||||
state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"]
|
||||
state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"]
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc1.weight")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc1.bias")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc2.weight")
|
||||
state_dict.pop(f"blocks.{i}.mlp.fc2.bias")
|
||||
|
||||
# pooler -> time_extra_emb
|
||||
state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"]
|
||||
state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"]
|
||||
state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"]
|
||||
state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"]
|
||||
state_dict.pop("pooler.k_proj.weight")
|
||||
state_dict.pop("pooler.k_proj.bias")
|
||||
state_dict.pop("pooler.q_proj.weight")
|
||||
state_dict.pop("pooler.q_proj.bias")
|
||||
state_dict.pop("pooler.v_proj.weight")
|
||||
state_dict.pop("pooler.v_proj.bias")
|
||||
state_dict.pop("pooler.c_proj.weight")
|
||||
state_dict.pop("pooler.c_proj.bias")
|
||||
state_dict.pop("pooler.positional_embedding")
|
||||
|
||||
# t_embedder -> time_embedding (`TimestepEmbedding`)
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"]
|
||||
state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"]
|
||||
|
||||
state_dict.pop("t_embedder.mlp.0.bias")
|
||||
state_dict.pop("t_embedder.mlp.0.weight")
|
||||
state_dict.pop("t_embedder.mlp.2.bias")
|
||||
state_dict.pop("t_embedder.mlp.2.weight")
|
||||
|
||||
# x_embedder -> pos_embd (`PatchEmbed`)
|
||||
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
|
||||
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
|
||||
state_dict.pop("x_embedder.proj.weight")
|
||||
state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# mlp_t5 -> text_embedder
|
||||
state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"]
|
||||
state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"]
|
||||
state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"]
|
||||
state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"]
|
||||
state_dict.pop("mlp_t5.0.bias")
|
||||
state_dict.pop("mlp_t5.0.weight")
|
||||
state_dict.pop("mlp_t5.2.bias")
|
||||
state_dict.pop("mlp_t5.2.weight")
|
||||
|
||||
# extra_embedder -> extra_embedder
|
||||
state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"]
|
||||
state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"]
|
||||
state_dict.pop("extra_embedder.0.bias")
|
||||
state_dict.pop("extra_embedder.0.weight")
|
||||
state_dict.pop("extra_embedder.2.bias")
|
||||
state_dict.pop("extra_embedder.2.weight")
|
||||
|
||||
# model.final_adaLN_modulation.1 -> norm_out.linear
|
||||
def swap_scale_shift(weight):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"])
|
||||
state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"])
|
||||
state_dict.pop("final_layer.adaLN_modulation.1.weight")
|
||||
state_dict.pop("final_layer.adaLN_modulation.1.bias")
|
||||
|
||||
# final_linear -> proj_out
|
||||
state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"]
|
||||
state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"]
|
||||
state_dict.pop("final_layer.linear.weight")
|
||||
state_dict.pop("final_layer.linear.bias")
|
||||
|
||||
# style_embedder
|
||||
if model_config["use_style_cond_and_image_meta_size"]:
|
||||
print(state_dict["style_embedder.weight"])
|
||||
print(state_dict["style_embedder.weight"].shape)
|
||||
state_dict["time_extra_emb.style_embedder.weight"] = state_dict["style_embedder.weight"][0:1]
|
||||
state_dict.pop("style_embedder.weight")
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
from diffusers import HunyuanDiTPipeline
|
||||
|
||||
if args.use_style_cond_and_image_meta_size:
|
||||
pipe = HunyuanDiTPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-Diffusers", transformer=model, torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
pipe = HunyuanDiTPipeline.from_pretrained(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", transformer=model, torch_dtype=torch.float32
|
||||
)
|
||||
pipe.to("cuda")
|
||||
pipe.to(dtype=torch.float32)
|
||||
|
||||
if args.save:
|
||||
pipe.save_pretrained(args.output_checkpoint_path)
|
||||
|
||||
# ### NOTE: HunyuanDiT supports both Chinese and English inputs
|
||||
prompt = "一个宇航员在骑马"
|
||||
# prompt = "An astronaut riding a horse"
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
image = pipe(
|
||||
height=1024, width=1024, prompt=prompt, generator=generator, num_inference_steps=25, guidance_scale=5.0
|
||||
).images[0]
|
||||
|
||||
image.save("img.png")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pt_checkpoint_path", default=None, type=str, required=True, help="Path to the .pt pretrained model."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Path to the output converted diffusers pipeline.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_key", default="none", type=str, required=False, help="The key to load from the pretrained .pt file"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_style_cond_and_image_meta_size",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="version <= v1.1: True; version >= v1.2: False",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,142 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline
|
||||
|
||||
|
||||
def main(args):
|
||||
# checkpoint from https://huggingface.co/Alpha-VLLM/Lumina-Next-SFT or https://huggingface.co/Alpha-VLLM/Lumina-Next-T2I
|
||||
all_sd = load_file(args.origin_ckpt_path, device="cpu")
|
||||
converted_state_dict = {}
|
||||
# pad token
|
||||
converted_state_dict["pad_token"] = all_sd["pad_token"]
|
||||
|
||||
# patch embed
|
||||
converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
|
||||
converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]
|
||||
|
||||
# time and caption embed
|
||||
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
|
||||
converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
|
||||
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
|
||||
converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
|
||||
converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
|
||||
converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
|
||||
converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
|
||||
converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]
|
||||
|
||||
for i in range(24):
|
||||
# adaln
|
||||
converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
|
||||
converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
|
||||
converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]
|
||||
|
||||
# qkv
|
||||
converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]
|
||||
|
||||
# cap
|
||||
converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]
|
||||
|
||||
# output
|
||||
converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]
|
||||
|
||||
# attention
|
||||
# qk norm
|
||||
converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
|
||||
|
||||
converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]
|
||||
|
||||
converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
|
||||
|
||||
converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]
|
||||
|
||||
# attention norm
|
||||
converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
|
||||
converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
|
||||
converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]
|
||||
|
||||
# feed forward
|
||||
converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
|
||||
converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
|
||||
converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]
|
||||
|
||||
# feed forward norm
|
||||
converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
|
||||
converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]
|
||||
|
||||
# final layer
|
||||
converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
|
||||
converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]
|
||||
|
||||
converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
|
||||
converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]
|
||||
|
||||
# Lumina-Next-SFT 2B
|
||||
transformer = LuminaNextDiT2DModel(
|
||||
sample_size=128,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=2304,
|
||||
num_layers=24,
|
||||
num_attention_heads=32,
|
||||
num_kv_heads=8,
|
||||
multiple_of=256,
|
||||
ffn_dim_multiplier=None,
|
||||
norm_eps=1e-5,
|
||||
learn_sigma=True,
|
||||
qk_norm=True,
|
||||
cross_attention_dim=2048,
|
||||
scaling_factor=1.0,
|
||||
)
|
||||
transformer.load_state_dict(converted_state_dict, strict=True)
|
||||
|
||||
num_model_params = sum(p.numel() for p in transformer.parameters())
|
||||
print(f"Total number of transformer parameters: {num_model_params}")
|
||||
|
||||
if args.only_transformer:
|
||||
transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
|
||||
text_encoder = AutoModel.from_pretrained("google/gemma-2b")
|
||||
|
||||
pipeline = LuminaText2ImgPipeline(
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
|
||||
)
|
||||
pipeline.save_pretrained(args.dump_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--origin_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_size",
|
||||
default=1024,
|
||||
type=int,
|
||||
choices=[256, 512, 1024],
|
||||
required=False,
|
||||
help="Image size of pretrained model, either 512 or 1024.",
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--only_transformer", default=True, type=bool, required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,248 +0,0 @@
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
from diffusers import AutoencoderKL, SD3Transformer2DModel
|
||||
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
|
||||
from diffusers.models.modeling_utils import load_model_dict_into_meta
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--checkpoint_path", type=str)
|
||||
parser.add_argument("--output_path", type=str)
|
||||
parser.add_argument("--dtype", type=str, default="fp16")
|
||||
|
||||
args = parser.parse_args()
|
||||
dtype = torch.float16 if args.dtype == "fp16" else torch.float32
|
||||
|
||||
|
||||
def load_original_checkpoint(ckpt_path):
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
keys = list(original_state_dict.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
original_state_dict[k.replace("model.diffusion_model.", "")] = original_state_dict.pop(k)
|
||||
|
||||
return original_state_dict
|
||||
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
|
||||
def swap_scale_shift(weight, dim):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
|
||||
converted_state_dict = {}
|
||||
|
||||
# Positional and patch embeddings.
|
||||
converted_state_dict["pos_embed.pos_embed"] = original_state_dict.pop("pos_embed")
|
||||
converted_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Timestep embeddings.
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
|
||||
"t_embedder.mlp.0.bias"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
|
||||
"t_embedder.mlp.2.bias"
|
||||
)
|
||||
|
||||
# Context projections.
|
||||
converted_state_dict["context_embedder.weight"] = original_state_dict.pop("context_embedder.weight")
|
||||
converted_state_dict["context_embedder.bias"] = original_state_dict.pop("context_embedder.bias")
|
||||
|
||||
# Pooled context projection.
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
|
||||
"y_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
|
||||
"y_embedder.mlp.0.bias"
|
||||
)
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
|
||||
"y_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
|
||||
"y_embedder.mlp.2.bias"
|
||||
)
|
||||
|
||||
# Transformer blocks 🎸.
|
||||
for i in range(num_layers):
|
||||
# Q, K, V
|
||||
sample_q, sample_k, sample_v = torch.chunk(
|
||||
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
|
||||
)
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
|
||||
)
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
|
||||
|
||||
# output projections.
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.x_block.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.x_block.attn.proj.bias"
|
||||
)
|
||||
if not (i == num_layers - 1):
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.context_block.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.context_block.attn.proj.bias"
|
||||
)
|
||||
|
||||
# norms.
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
|
||||
)
|
||||
if not (i == num_layers - 1):
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
|
||||
)
|
||||
else:
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
|
||||
original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
|
||||
dim=caption_projection_dim,
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
|
||||
original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
|
||||
dim=caption_projection_dim,
|
||||
)
|
||||
|
||||
# ffs.
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.x_block.mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.x_block.mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.x_block.mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.x_block.mlp.fc2.bias"
|
||||
)
|
||||
if not (i == num_layers - 1):
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.context_block.mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.context_block.mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.context_block.mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = original_state_dict.pop(
|
||||
f"joint_blocks.{i}.context_block.mlp.fc2.bias"
|
||||
)
|
||||
|
||||
# Final blocks.
|
||||
converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
||||
original_state_dict.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
|
||||
)
|
||||
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
||||
original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
|
||||
)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def is_vae_in_checkpoint(original_state_dict):
|
||||
return ("first_stage_model.decoder.conv_in.weight" in original_state_dict) and (
|
||||
"first_stage_model.encoder.conv_in.weight" in original_state_dict
|
||||
)
|
||||
|
||||
|
||||
def main(args):
|
||||
original_ckpt = load_original_checkpoint(args.checkpoint_path)
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401
|
||||
caption_projection_dim = 1536
|
||||
|
||||
converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
|
||||
original_ckpt, num_layers, caption_projection_dim
|
||||
)
|
||||
|
||||
with CTX():
|
||||
transformer = SD3Transformer2DModel(
|
||||
sample_size=64,
|
||||
patch_size=2,
|
||||
in_channels=16,
|
||||
joint_attention_dim=4096,
|
||||
num_layers=num_layers,
|
||||
caption_projection_dim=caption_projection_dim,
|
||||
num_attention_heads=24,
|
||||
pos_embed_max_size=192,
|
||||
)
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(transformer, converted_transformer_state_dict)
|
||||
else:
|
||||
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
||||
|
||||
print("Saving SD3 Transformer in Diffusers format.")
|
||||
transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
|
||||
|
||||
if is_vae_in_checkpoint(original_ckpt):
|
||||
with CTX():
|
||||
vae = AutoencoderKL.from_config(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
subfolder="vae",
|
||||
latent_channels=16,
|
||||
use_post_quant_conv=False,
|
||||
use_quant_conv=False,
|
||||
scaling_factor=1.5305,
|
||||
shift_factor=0.0609,
|
||||
)
|
||||
converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
|
||||
if is_accelerate_available():
|
||||
load_model_dict_into_meta(vae, converted_vae_state_dict)
|
||||
else:
|
||||
vae.load_state_dict(converted_vae_state_dict, strict=True)
|
||||
|
||||
print("Saving SD3 Autoencoder in Diffusers format.")
|
||||
vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -76,7 +76,6 @@ else:
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AsymmetricAutoencoderKL",
|
||||
"AuraFlowTransformer2DModel",
|
||||
"AutoencoderKL",
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderTiny",
|
||||
@@ -84,13 +83,9 @@ else:
|
||||
"ControlNetModel",
|
||||
"ControlNetXSAdapter",
|
||||
"DiTTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
"HunyuanDiT2DMultiControlNetModel",
|
||||
"I2VGenXLUNet",
|
||||
"Kandinsky3UNet",
|
||||
"LatteTransformer3DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
"ModelMixin",
|
||||
"MotionAdapter",
|
||||
"MultiAdapter",
|
||||
@@ -165,7 +160,6 @@ else:
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"FlowMatchEulerDiscreteScheduler",
|
||||
"FlowMatchHeunDiscreteScheduler",
|
||||
"HeunDiscreteScheduler",
|
||||
"IPNDMScheduler",
|
||||
"KarrasVeScheduler",
|
||||
@@ -236,14 +230,10 @@ else:
|
||||
"AudioLDM2ProjectionModel",
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
"AudioLDMPipeline",
|
||||
"AuraFlowPipeline",
|
||||
"BlipDiffusionControlNetPipeline",
|
||||
"BlipDiffusionPipeline",
|
||||
"ChatGLMModel",
|
||||
"ChatGLMTokenizer",
|
||||
"CLIPImageProjection",
|
||||
"CycleDiffusionPipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPipeline",
|
||||
"I2VGenXLPipeline",
|
||||
"IFImg2ImgPipeline",
|
||||
@@ -272,15 +262,11 @@ else:
|
||||
"KandinskyV22Pipeline",
|
||||
"KandinskyV22PriorEmb2EmbPipeline",
|
||||
"KandinskyV22PriorPipeline",
|
||||
"KolorsImg2ImgPipeline",
|
||||
"KolorsPipeline",
|
||||
"LatentConsistencyModelImg2ImgPipeline",
|
||||
"LatentConsistencyModelPipeline",
|
||||
"LattePipeline",
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LuminaText2ImgPipeline",
|
||||
"MarigoldDepthPipeline",
|
||||
"MarigoldNormalsPipeline",
|
||||
"MusicLDMPipeline",
|
||||
@@ -296,13 +282,11 @@ else:
|
||||
"StableCascadePriorPipeline",
|
||||
"StableDiffusion3ControlNetPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3InpaintPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
"StableDiffusionControlNetInpaintPipeline",
|
||||
"StableDiffusionControlNetPAGPipeline",
|
||||
"StableDiffusionControlNetPipeline",
|
||||
"StableDiffusionControlNetXSPipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
@@ -317,7 +301,6 @@ else:
|
||||
"StableDiffusionLatentUpscalePipeline",
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionModelEditingPipeline",
|
||||
"StableDiffusionPAGPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionParadigmsPipeline",
|
||||
"StableDiffusionPipeline",
|
||||
@@ -328,15 +311,11 @@ else:
|
||||
"StableDiffusionXLAdapterPipeline",
|
||||
"StableDiffusionXLControlNetImg2ImgPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline",
|
||||
"StableDiffusionXLControlNetPAGPipeline",
|
||||
"StableDiffusionXLControlNetPipeline",
|
||||
"StableDiffusionXLControlNetXSPipeline",
|
||||
"StableDiffusionXLImg2ImgPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLInstructPix2PixPipeline",
|
||||
"StableDiffusionXLPAGImg2ImgPipeline",
|
||||
"StableDiffusionXLPAGInpaintPipeline",
|
||||
"StableDiffusionXLPAGPipeline",
|
||||
"StableDiffusionXLPipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
@@ -510,7 +489,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .models import (
|
||||
AsymmetricAutoencoderKL,
|
||||
AuraFlowTransformer2DModel,
|
||||
AutoencoderKL,
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderTiny,
|
||||
@@ -518,13 +496,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ControlNetModel,
|
||||
ControlNetXSAdapter,
|
||||
DiTTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanDiT2DMultiControlNetModel,
|
||||
I2VGenXLUNet,
|
||||
Kandinsky3UNet,
|
||||
LatteTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
ModelMixin,
|
||||
MotionAdapter,
|
||||
MultiAdapter,
|
||||
@@ -596,7 +570,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FlowMatchHeunDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
@@ -650,12 +623,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AudioLDM2ProjectionModel,
|
||||
AudioLDM2UNet2DConditionModel,
|
||||
AudioLDMPipeline,
|
||||
AuraFlowPipeline,
|
||||
ChatGLMModel,
|
||||
ChatGLMTokenizer,
|
||||
CLIPImageProjection,
|
||||
CycleDiffusionPipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPipeline,
|
||||
I2VGenXLPipeline,
|
||||
IFImg2ImgPipeline,
|
||||
@@ -684,15 +653,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorEmb2EmbPipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
KolorsImg2ImgPipeline,
|
||||
KolorsPipeline,
|
||||
LatentConsistencyModelImg2ImgPipeline,
|
||||
LatentConsistencyModelPipeline,
|
||||
LattePipeline,
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LuminaText2ImgPipeline,
|
||||
MarigoldDepthPipeline,
|
||||
MarigoldNormalsPipeline,
|
||||
MusicLDMPipeline,
|
||||
@@ -708,13 +673,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableCascadePriorPipeline,
|
||||
StableDiffusion3ControlNetPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3InpaintPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
StableDiffusionControlNetInpaintPipeline,
|
||||
StableDiffusionControlNetPAGPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionControlNetXSPipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
@@ -729,7 +692,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionLatentUpscalePipeline,
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionModelEditingPipeline,
|
||||
StableDiffusionPAGPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionParadigmsPipeline,
|
||||
StableDiffusionPipeline,
|
||||
@@ -740,15 +702,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
StableDiffusionXLControlNetImg2ImgPipeline,
|
||||
StableDiffusionXLControlNetInpaintPipeline,
|
||||
StableDiffusionXLControlNetPAGPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
StableDiffusionXLControlNetXSPipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLInstructPix2PixPipeline,
|
||||
StableDiffusionXLPAGImg2ImgPipeline,
|
||||
StableDiffusionXLPAGInpaintPipeline,
|
||||
StableDiffusionXLPAGPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
|
||||
@@ -23,7 +23,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from pathlib import PosixPath
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -310,6 +310,9 @@ class ConfigMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
||||
of Diffusers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
@@ -340,6 +343,7 @@ class ConfigMixin:
|
||||
local_dir = kwargs.pop("local_dir", None)
|
||||
local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
@@ -382,6 +386,7 @@ class ConfigMixin:
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
user_agent=user_agent,
|
||||
@@ -582,8 +587,8 @@ class ConfigMixin:
|
||||
def to_json_saveable(value):
|
||||
if isinstance(value, np.ndarray):
|
||||
value = value.tolist()
|
||||
elif isinstance(value, Path):
|
||||
value = value.as_posix()
|
||||
elif isinstance(value, PosixPath):
|
||||
value = str(value)
|
||||
return value
|
||||
|
||||
config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_vae_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalVAEMixin:
|
||||
"""
|
||||
Load pretrained AutoencoderKL weights saved in the `.ckpt` or `.safetensors` format into a [`AutoencoderKL`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`AutoencoderKL`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
config_file (`str`, *optional*):
|
||||
Filepath to the configuration YAML file associated with the model. If not provided it will default to:
|
||||
https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
||||
of Diffusers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
scaling_factor (`float`, *optional*, defaults to 0.18215):
|
||||
The component-wise standard deviation of the trained latent space computed using the first batch of the
|
||||
training set. This is used to scale the latent space to have unit variance when training the diffusion
|
||||
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
|
||||
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
|
||||
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
|
||||
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure to pass both `image_size` and `scaling_factor` to `from_single_file()` if you're loading
|
||||
a VAE from SDXL or a Stable Diffusion v2 model or higher.
|
||||
|
||||
</Tip>
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import AutoencoderKL
|
||||
|
||||
url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors" # can also be local file
|
||||
model = AutoencoderKL.from_single_file(url)
|
||||
```
|
||||
"""
|
||||
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
class_name = cls.__name__
|
||||
|
||||
if (config_file is not None) and (original_config_file is not None):
|
||||
raise ValueError(
|
||||
"You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
|
||||
)
|
||||
|
||||
original_config_file = original_config_file or config_file
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
scaling_factor = kwargs.pop("scaling_factor", None)
|
||||
component = create_diffusers_vae_model_from_ldm(
|
||||
class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
image_size=image_size,
|
||||
scaling_factor=scaling_factor,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
vae = component["vae"]
|
||||
if torch_dtype is not None:
|
||||
vae = vae.to(torch_dtype)
|
||||
|
||||
return vae
|
||||
@@ -0,0 +1,136 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from .single_file_utils import (
|
||||
create_diffusers_controlnet_model_from_ldm,
|
||||
fetch_ldm_config_and_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
class FromOriginalControlNetMixin:
|
||||
"""
|
||||
Load pretrained ControlNet weights saved in the `.ckpt` or `.safetensors` format into a [`ControlNetModel`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
|
||||
r"""
|
||||
Instantiate a [`ControlNetModel`] from pretrained ControlNet weights saved in the original `.ckpt` or
|
||||
`.safetensors` format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
config_file (`str`, *optional*):
|
||||
Filepath to the configuration YAML file associated with the model. If not provided it will default to:
|
||||
https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
||||
of Diffusers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
image_size (`int`, *optional*, defaults to 512):
|
||||
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
|
||||
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
|
||||
upcast_attention (`bool`, *optional*, defaults to `None`):
|
||||
Whether the attention computation should always be upcasted.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
|
||||
|
||||
url = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth" # can also be a local path
|
||||
model = ControlNetModel.from_single_file(url)
|
||||
|
||||
url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned.safetensors" # can also be a local path
|
||||
pipe = StableDiffusionControlNetPipeline.from_single_file(url, controlnet=controlnet)
|
||||
```
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config_file = kwargs.pop("config_file", None)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
|
||||
class_name = cls.__name__
|
||||
if (config_file is not None) and (original_config_file is not None):
|
||||
raise ValueError(
|
||||
"You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
|
||||
)
|
||||
|
||||
original_config_file = config_file or original_config_file
|
||||
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
|
||||
pretrained_model_link_or_path=pretrained_model_link_or_path,
|
||||
class_name=class_name,
|
||||
original_config_file=original_config_file,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
upcast_attention = kwargs.pop("upcast_attention", False)
|
||||
image_size = kwargs.pop("image_size", None)
|
||||
|
||||
component = create_diffusers_controlnet_model_from_ldm(
|
||||
class_name,
|
||||
original_config,
|
||||
checkpoint,
|
||||
upcast_attention=upcast_attention,
|
||||
image_size=image_size,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
controlnet = component["controlnet"]
|
||||
if torch_dtype is not None:
|
||||
controlnet = controlnet.to(torch_dtype)
|
||||
|
||||
return controlnet
|
||||
@@ -90,7 +90,9 @@ class IPAdapterMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
||||
of Diffusers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
@@ -133,6 +135,7 @@ class IPAdapterMixin:
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -168,6 +171,7 @@ class IPAdapterMixin:
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
|
||||
@@ -170,7 +170,9 @@ class LoraLoaderMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
||||
of Diffusers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
@@ -192,6 +194,7 @@ class LoraLoaderMixin:
|
||||
# UNet and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -232,6 +235,7 @@ class LoraLoaderMixin:
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -257,6 +261,7 @@ class LoraLoaderMixin:
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -391,7 +396,8 @@ class LoraLoaderMixin:
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
|
||||
if not only_text_encoder:
|
||||
|
||||
if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
|
||||
# Load the layers corresponding to UNet.
|
||||
logger.info(f"Loading {cls.unet_name}.")
|
||||
unet.load_attn_procs(
|
||||
@@ -1422,7 +1428,9 @@ class SD3LoraLoaderMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
@@ -1443,6 +1451,7 @@ class SD3LoraLoaderMixin:
|
||||
# UNet and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -1473,6 +1482,7 @@ class SD3LoraLoaderMixin:
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -1494,6 +1504,7 @@ class SD3LoraLoaderMixin:
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -1590,8 +1601,6 @@ class SD3LoraLoaderMixin:
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
@@ -1623,20 +1632,12 @@ class SD3LoraLoaderMixin:
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||
raise ValueError(
|
||||
"You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
|
||||
)
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
if text_encoder_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
|
||||
if text_encoder_2_lora_layers:
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
|
||||
@@ -142,10 +142,10 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
network_alphas = {}
|
||||
|
||||
# Check for DoRA-enabled LoRAs.
|
||||
dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
|
||||
dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
|
||||
dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
|
||||
if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
|
||||
if any(
|
||||
"dora_scale" in k and ("lora_unet_" in k or "lora_te_" in k or "lora_te1_" in k or "lora_te2_" in k)
|
||||
for k in state_dict
|
||||
):
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
@@ -173,7 +173,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
# Store DoRA scale if present.
|
||||
if dora_present_in_unet:
|
||||
if "dora_scale" in state_dict:
|
||||
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
||||
unet_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||
@@ -192,7 +192,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
# Store DoRA scale if present.
|
||||
if dora_present_in_te or dora_present_in_te2:
|
||||
if "dora_scale" in state_dict:
|
||||
dora_scale_key_to_replace_te = (
|
||||
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
||||
)
|
||||
@@ -214,7 +214,7 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
|
||||
|
||||
logger.info("Non-diffusers checkpoint detected.")
|
||||
logger.info("Kohya-style checkpoint detected.")
|
||||
|
||||
# Construct final state dict.
|
||||
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
|
||||
|
||||
@@ -242,6 +242,7 @@ def _download_diffusers_model_config_from_hub(
|
||||
revision,
|
||||
proxies,
|
||||
force_download=None,
|
||||
resume_download=None,
|
||||
local_files_only=None,
|
||||
token=None,
|
||||
):
|
||||
@@ -252,6 +253,7 @@ def _download_diffusers_model_config_from_hub(
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
allow_patterns=allow_patterns,
|
||||
@@ -286,7 +288,9 @@ class FromSingleFileMixin:
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
||||
of Diffusers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
@@ -348,6 +352,7 @@ class FromSingleFileMixin:
|
||||
deprecate("original_config_file", "1.0.0", deprecation_message)
|
||||
original_config = original_config_file
|
||||
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -377,6 +382,7 @@ class FromSingleFileMixin:
|
||||
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
@@ -406,6 +412,7 @@ class FromSingleFileMixin:
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
)
|
||||
@@ -428,6 +435,7 @@ class FromSingleFileMixin:
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
local_files_only=False,
|
||||
token=token,
|
||||
)
|
||||
@@ -547,4 +555,7 @@ class FromSingleFileMixin:
|
||||
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
@@ -22,7 +21,6 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
@@ -71,23 +69,9 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"MotionAdapter": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
loadable_class = getattr(diffusers_module, loadable_class_str)
|
||||
|
||||
if issubclass(cls, loadable_class):
|
||||
return loadable_class_str
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||
parameters = inspect.signature(mapping_fn).parameters
|
||||
|
||||
@@ -137,7 +121,9 @@ class FromOriginalModelMixin:
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
@@ -163,9 +149,8 @@ class FromOriginalModelMixin:
|
||||
```
|
||||
"""
|
||||
|
||||
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
|
||||
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
if mapping_class_name is None:
|
||||
class_name = cls.__name__
|
||||
if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
raise ValueError(
|
||||
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
||||
)
|
||||
@@ -186,6 +171,7 @@ class FromOriginalModelMixin:
|
||||
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
||||
)
|
||||
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -200,6 +186,7 @@ class FromOriginalModelMixin:
|
||||
else:
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path_or_dict,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
@@ -208,7 +195,7 @@ class FromOriginalModelMixin:
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[class_name]
|
||||
|
||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||
if original_config:
|
||||
@@ -220,7 +207,7 @@ class FromOriginalModelMixin:
|
||||
if config_mapping_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
|
||||
f"`original_config` has been provided for {class_name} but no mapping function"
|
||||
"was found to convert the original config to a Diffusers config in"
|
||||
"`diffusers.loaders.single_file_utils`"
|
||||
)
|
||||
@@ -280,7 +267,7 @@ class FromOriginalModelMixin:
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
@@ -74,9 +74,6 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
||||
"stable_cascade_stage_c": "clip_txt_mapper.weight",
|
||||
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe",
|
||||
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
|
||||
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -106,10 +103,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"sd3": {
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
},
|
||||
"animatediff_v1": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5"},
|
||||
"animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"},
|
||||
"animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"},
|
||||
"animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -313,6 +306,7 @@ def _is_model_weights_in_cached_folder(cached_folder, name):
|
||||
|
||||
def load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
resume_download=False,
|
||||
force_download=False,
|
||||
proxies=None,
|
||||
token=None,
|
||||
@@ -330,6 +324,7 @@ def load_single_file_checkpoint(
|
||||
weights_name=weights_name,
|
||||
force_download=force_download,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -490,19 +485,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
|
||||
model_type = "sd3"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint:
|
||||
model_type = "animatediff_v2"
|
||||
|
||||
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320:
|
||||
model_type = "animatediff_sdxl_beta"
|
||||
|
||||
elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff"]].shape[1] == 24:
|
||||
model_type = "animatediff_v1"
|
||||
|
||||
else:
|
||||
model_type = "animatediff_v3"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1826,36 +1808,4 @@ def create_diffusers_t5_model_from_checkpoint(
|
||||
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (torch_dtype == torch.float16)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = model._keep_in_fp32_modules
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if keep_in_fp32_modules is not None:
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules):
|
||||
# param = param.to(torch.float32) does not work here as only in the local scope.
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def convert_animatediff_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
for k, v in checkpoint.items():
|
||||
if "pos_encoder" in k:
|
||||
continue
|
||||
|
||||
else:
|
||||
converted_state_dict[
|
||||
k.replace(".norms.0", ".norm1")
|
||||
.replace(".norms.1", ".norm2")
|
||||
.replace(".ff_norm", ".norm3")
|
||||
.replace(".attention_blocks.0", ".attn1")
|
||||
.replace(".attention_blocks.1", ".attn2")
|
||||
.replace(".temporal_transformer", "")
|
||||
] = v
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -38,6 +38,7 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
||||
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -71,6 +72,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
|
||||
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -91,6 +93,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
|
||||
weights_name=weight_name or TEXT_INVERSION_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -305,7 +308,9 @@ class TextualInversionLoaderMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
||||
of Diffusers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
|
||||
@@ -97,7 +97,9 @@ class UNet2DConditionLoadersMixin:
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
resume_download:
|
||||
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
||||
of Diffusers.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
@@ -138,6 +140,7 @@ class UNet2DConditionLoadersMixin:
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", None)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
@@ -171,6 +174,7 @@ class UNet2DConditionLoadersMixin:
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -190,6 +194,7 @@ class UNet2DConditionLoadersMixin:
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
@@ -452,15 +457,6 @@ class UNet2DConditionLoadersMixin:
|
||||
)
|
||||
if is_custom_diffusion:
|
||||
state_dict = self._get_custom_diffusion_state_dict()
|
||||
if save_function is None and safe_serialization:
|
||||
# safetensors does not support saving dicts with non-tensor values
|
||||
empty_state_dict = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)}
|
||||
if len(empty_state_dict) > 0:
|
||||
logger.warning(
|
||||
f"Safetensors does not support saving dicts with non-tensor values. "
|
||||
f"The following keys will be ignored: {empty_state_dict.keys()}"
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
|
||||
else:
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
|
||||
@@ -926,6 +922,8 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
|
||||
from ..models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
)
|
||||
@@ -965,7 +963,9 @@ class UNet2DConditionLoadersMixin:
|
||||
hidden_size = self.config.block_out_channels[block_id]
|
||||
|
||||
if cross_attention_dim is None or "motion_modules" in name:
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
|
||||
)
|
||||
attn_procs[name] = attn_processor_class()
|
||||
|
||||
else:
|
||||
|
||||
@@ -33,17 +33,13 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
_import_structure["controlnet"] = ["ControlNetModel"]
|
||||
_import_structure["controlnet_hunyuan"] = ["HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel"]
|
||||
_import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
|
||||
_import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
_import_structure["modeling_utils"] = ["ModelMixin"]
|
||||
_import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"]
|
||||
_import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
|
||||
_import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
|
||||
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
|
||||
_import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
|
||||
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
|
||||
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
|
||||
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
@@ -79,18 +75,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
VQModel,
|
||||
)
|
||||
from .controlnet import ControlNetModel
|
||||
from .controlnet_hunyuan import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
|
||||
from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
|
||||
from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
from .transformers import (
|
||||
AuraFlowTransformer2DModel,
|
||||
DiTTransformer2DModel,
|
||||
DualTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
LatteTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SD3Transformer2DModel,
|
||||
|
||||
@@ -19,7 +19,7 @@ from torch import nn
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU
|
||||
from .activations import GEGLU, GELU, ApproximateGELU
|
||||
from .attention_processor import Attention, JointAttnProcessor2_0
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
||||
@@ -128,9 +128,9 @@ class JointTransformerBlock(nn.Module):
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
dim_head=attention_head_dim // num_attention_heads,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
out_dim=attention_head_dim,
|
||||
context_pre_only=context_pre_only,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
@@ -359,10 +359,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
if norm_type == "ada_norm_single": # For Latte
|
||||
self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
|
||||
else:
|
||||
self.norm2 = None
|
||||
self.norm2 = None
|
||||
self.attn2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
@@ -442,6 +439,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
).chunk(6, dim=1)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
else:
|
||||
raise ValueError("Incorrect norm used")
|
||||
|
||||
@@ -458,7 +456,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
|
||||
if self.norm_type == "ada_norm_zero":
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
elif self.norm_type == "ada_norm_single":
|
||||
@@ -530,56 +527,6 @@ class BasicTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LuminaFeedForward(nn.Module):
|
||||
r"""
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
hidden_size (`int`):
|
||||
The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
|
||||
hidden representations.
|
||||
intermediate_size (`int`): The intermediate dimension of the feedforward layer.
|
||||
multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
|
||||
of this value.
|
||||
ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
|
||||
dimension. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: int,
|
||||
multiple_of: Optional[int] = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(2 * inner_dim / 3)
|
||||
# custom hidden_size factor multiplier
|
||||
if ffn_dim_multiplier is not None:
|
||||
inner_dim = int(ffn_dim_multiplier * inner_dim)
|
||||
inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
|
||||
|
||||
self.linear_1 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_2 = nn.Linear(
|
||||
inner_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
self.linear_3 = nn.Linear(
|
||||
dim,
|
||||
inner_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.silu = FP32SiLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class TemporalBasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
|
||||
@@ -22,7 +22,7 @@ from torch import nn
|
||||
from ..image_processor import IPAdapterMaskProcessor
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
|
||||
from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -94,7 +94,6 @@ class Attention(nn.Module):
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
kv_heads: Optional[int] = None,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias: bool = False,
|
||||
@@ -104,7 +103,6 @@ class Attention(nn.Module):
|
||||
cross_attention_norm_num_groups: int = 32,
|
||||
qk_norm: Optional[str] = None,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
added_proj_bias: Optional[bool] = True,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
spatial_norm_dim: Optional[int] = None,
|
||||
out_bias: bool = True,
|
||||
@@ -119,12 +117,7 @@ class Attention(nn.Module):
|
||||
context_pre_only=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# To prevent circular import.
|
||||
from .normalization import FP32LayerNorm
|
||||
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
|
||||
self.query_dim = query_dim
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim is not None
|
||||
@@ -175,13 +168,6 @@ class Attention(nn.Module):
|
||||
elif qk_norm == "layer_norm":
|
||||
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
|
||||
elif qk_norm == "fp32_layer_norm":
|
||||
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
elif qk_norm == "layer_norm_across_heads":
|
||||
# Lumina applys qk norm across all heads
|
||||
self.norm_q = nn.LayerNorm(dim_head * heads, eps=eps)
|
||||
self.norm_k = nn.LayerNorm(dim_head * kv_heads, eps=eps)
|
||||
else:
|
||||
raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None or 'layer_norm'")
|
||||
|
||||
@@ -212,17 +198,17 @@ class Attention(nn.Module):
|
||||
|
||||
if not self.only_cross_attention:
|
||||
# only relevant for the `AddedKVProcessor` classes
|
||||
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
|
||||
self.to_k = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(self.cross_attention_dim, self.inner_dim, bias=bias)
|
||||
else:
|
||||
self.to_k = None
|
||||
self.to_v = None
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias)
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
@@ -231,14 +217,6 @@ class Attention(nn.Module):
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
||||
|
||||
if qk_norm is not None and added_kv_proj_dim is not None:
|
||||
if qk_norm == "fp32_layer_norm":
|
||||
self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
|
||||
else:
|
||||
self.norm_added_q = None
|
||||
self.norm_added_k = None
|
||||
|
||||
# set attention processor
|
||||
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
||||
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
||||
@@ -1128,7 +1106,9 @@ class FusedJointAttnProcessor2_0:
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
@@ -1153,100 +1133,6 @@ class FusedJointAttnProcessor2_0:
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class AuraFlowAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing Aura Flow."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention") and is_torch_version("<", "2.1"):
|
||||
raise ImportError(
|
||||
"AuraFlowAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to at least 2.1 or above as we use `scale` in `F.scaled_dot_product_attention()`. "
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
i=0,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# `context` projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# Reshape.
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
# Apply QK norm.
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Concatenate the projections.
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
)
|
||||
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(batch_size, -1, attn.heads, head_dim)
|
||||
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
||||
batch_size, -1, attn.heads, head_dim
|
||||
)
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_hidden_states_key_proj = attn.norm_added_q(encoder_hidden_states_key_proj)
|
||||
|
||||
query = torch.cat([encoder_hidden_states_query_proj, query], dim=1)
|
||||
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# Attention.
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
if encoder_hidden_states is not None:
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :],
|
||||
hidden_states[:, : encoder_hidden_states.shape[1]],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class XFormersAttnAddedKVProcessor:
|
||||
r"""
|
||||
Processor for implementing memory efficient attention using xFormers.
|
||||
@@ -1708,102 +1594,6 @@ class HunyuanAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LuminaAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
|
||||
used in the LuminaNextDiT model. It applies a s normalization layer and rotary embedding on query and key vector.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
query_rotary_emb: Optional[torch.Tensor] = None,
|
||||
key_rotary_emb: Optional[torch.Tensor] = None,
|
||||
base_sequence_length: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
from .embeddings import apply_rotary_emb
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
# Get Query-Key-Value Pair
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query_dim = query.shape[-1]
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = query_dim // attn.heads
|
||||
dtype = query.dtype
|
||||
|
||||
# Get key-value heads
|
||||
kv_heads = inner_dim // head_dim
|
||||
|
||||
# Apply Query-Key Norm if needed
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim)
|
||||
|
||||
key = key.view(batch_size, -1, kv_heads, head_dim)
|
||||
value = value.view(batch_size, -1, kv_heads, head_dim)
|
||||
|
||||
# Apply RoPE if needed
|
||||
if query_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
|
||||
if key_rotary_emb is not None:
|
||||
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)
|
||||
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# Apply proportional attention if true
|
||||
if key_rotary_emb is None:
|
||||
softmax_scale = None
|
||||
else:
|
||||
if base_sequence_length is not None:
|
||||
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
|
||||
else:
|
||||
softmax_scale = attn.scale
|
||||
|
||||
# perform Grouped-qurey Attention (GQA)
|
||||
n_rep = attn.heads // kv_heads
|
||||
if n_rep >= 1:
|
||||
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
|
||||
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
|
||||
attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, scale=softmax_scale
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).to(dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FusedAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
|
||||
@@ -2771,240 +2561,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAGIdentitySelfAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
PAG reference: https://arxiv.org/abs/2403.17377
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"PAGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
# chunk
|
||||
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)
|
||||
|
||||
# original path
|
||||
batch_size, sequence_length, _ = hidden_states_org.shape
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states_org)
|
||||
key = attn.to_k(hidden_states_org)
|
||||
value = attn.to_v(hidden_states_org)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states_org = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_org = hidden_states_org.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states_org = attn.to_out[0](hidden_states_org)
|
||||
# dropout
|
||||
hidden_states_org = attn.to_out[1](hidden_states_org)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
# perturbed path (identity attention)
|
||||
batch_size, sequence_length, _ = hidden_states_ptb.shape
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
hidden_states_ptb = attn.to_v(hidden_states_ptb)
|
||||
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
||||
# dropout
|
||||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
# cat
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class PAGCFGIdentitySelfAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
PAG reference: https://arxiv.org/abs/2403.17377
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"PAGCFGIdentitySelfAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
# chunk
|
||||
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
|
||||
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
|
||||
|
||||
# original path
|
||||
batch_size, sequence_length, _ = hidden_states_org.shape
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states_org)
|
||||
key = attn.to_k(hidden_states_org)
|
||||
value = attn.to_v(hidden_states_org)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states_org = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states_org = hidden_states_org.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states_org = attn.to_out[0](hidden_states_org)
|
||||
# dropout
|
||||
hidden_states_org = attn.to_out[1](hidden_states_org)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
# perturbed path (identity attention)
|
||||
batch_size, sequence_length, _ = hidden_states_ptb.shape
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
value = attn.to_v(hidden_states_ptb)
|
||||
hidden_states_ptb = value
|
||||
hidden_states_ptb = hidden_states_ptb.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
|
||||
# dropout
|
||||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
# cat
|
||||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LoRAAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAAttnProcessor2_0:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAXFormersAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class LoRAAttnAddedKVProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
ADDED_KV_ATTENTION_PROCESSORS = (
|
||||
AttnAddedKVProcessor,
|
||||
SlicedAttnAddedKVProcessor,
|
||||
@@ -3034,6 +2590,4 @@ AttentionProcessor = Union[
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
CustomDiffusionAttnProcessor2_0,
|
||||
PAGCFGIdentitySelfAttnProcessor2_0,
|
||||
PAGIdentitySelfAttnProcessor2_0,
|
||||
]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user